// SPDX-License-Identifier: GPL-2.0 or Linux-OpenIB
/* Copyright (c) 2015 - 2020 Intel Corporation */
#include "osdep.h"
#include "status.h"
#include "hmc.h"
#include "defs.h"
#include "type.h"
#include "protos.h"
#include "virtchnl.h"
#include "ws.h"
#include "i40iw_hw.h"

/**
 * vchnl_vf_send_get_ver_req - Request Channel version
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 */
static enum irdma_status_code vchnl_vf_send_get_ver_req(struct irdma_sc_dev *dev,
							struct irdma_virtchnl_req *vchnl_req)
{
	enum irdma_status_code ret_code = IRDMA_ERR_NOT_READY;
	struct irdma_virtchnl_op_buf *vchnl_msg = vchnl_req->vchnl_msg;

	if (!dev->vchnl_up)
		return ret_code;

	memset(vchnl_msg, 0, sizeof(*vchnl_msg));
	vchnl_msg->op_ctx = (uintptr_t)vchnl_req;
	vchnl_msg->buf_len = sizeof(*vchnl_msg);
	vchnl_msg->op_code = IRDMA_VCHNL_OP_GET_VER;
	vchnl_msg->op_ver = IRDMA_VCHNL_OP_GET_VER_V1;
	dev->vchnl_req = vchnl_req;
	ret_code = dev->vchnl_if->vchnl_send(dev,
					    0,
					    (u8 *)vchnl_msg,
					    vchnl_msg->buf_len);
	if (ret_code)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
	return ret_code;
}

/**
 * vchnl_vf_send_get_pe_stats_req - Request PE stats from VF
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 */
static enum irdma_status_code vchnl_vf_send_get_pe_stats_req(struct irdma_sc_dev *dev,
							     struct irdma_virtchnl_req *vchnl_req)
{
	enum irdma_status_code ret_code = IRDMA_ERR_NOT_READY;
	struct irdma_virtchnl_op_buf *vchnl_msg = vchnl_req->vchnl_msg;

	if (!dev->vchnl_up)
		return ret_code;

	memset(vchnl_msg, 0, sizeof(*vchnl_msg));
	vchnl_msg->op_ctx = (uintptr_t)vchnl_req;
	vchnl_msg->buf_len = sizeof(*vchnl_msg) + vchnl_req->parm_len;
	vchnl_msg->op_code = IRDMA_VCHNL_OP_GET_STATS;
	vchnl_msg->op_ver = IRDMA_VCHNL_OP_GET_STATS_V0;
	dev->vchnl_req = vchnl_req;
	ret_code = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)vchnl_msg,
					    vchnl_msg->buf_len);
	if (ret_code)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
	return ret_code;
}

/**
 * irdma_vchnl_vf_get_pe_stats - Get PE stats
 * @dev: RDMA device pointer
 * @hw_stats: HW stats struct
 */
enum irdma_status_code irdma_vchnl_vf_get_pe_stats(struct irdma_sc_dev *dev,
						   struct irdma_dev_hw_stats *hw_stats)
{
	struct irdma_virtchnl_req  vchnl_req = {};
	enum irdma_status_code ret_code;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	vchnl_req.dev = dev;
	vchnl_req.parm = hw_stats;
	vchnl_req.parm_len = IRDMA_VF_STATS_SIZE_V0;
	vchnl_req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;

	ret_code = vchnl_vf_send_get_pe_stats_req(dev, &vchnl_req);
	if (ret_code) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", ret_code);
		return ret_code;
	}
	ret_code = irdma_vf_wait_vchnl_resp(dev);
	if (ret_code)
		return ret_code;
	else
		return vchnl_req.ret_code;
}

/**
 * vchnl_vf_send_get_hmc_fcn_req - Request HMC Function from VF
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 */
static enum irdma_status_code vchnl_vf_send_get_hmc_fcn_req(struct irdma_sc_dev *dev,
							    struct irdma_virtchnl_req *vchnl_req)
{
	enum irdma_status_code ret_code = IRDMA_ERR_NOT_READY;
	struct irdma_virtchnl_op_buf *vchnl_msg = vchnl_req->vchnl_msg;

	if (!dev->vchnl_up)
		return ret_code;

	memset(vchnl_msg, 0, sizeof(*vchnl_msg));
	vchnl_msg->op_ctx = (uintptr_t)vchnl_req;
	vchnl_msg->buf_len = sizeof(*vchnl_msg);
	vchnl_msg->op_code = IRDMA_VCHNL_OP_GET_HMC_FCN;
	vchnl_msg->op_ver = IRDMA_VCHNL_OP_GET_HMC_FCN_V0;
	dev->vchnl_req = vchnl_req;
	ret_code = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)vchnl_msg,
					    vchnl_msg->buf_len);
	if (ret_code)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
	return ret_code;
}

/**
 * irdma_vchnl_vf_send_put_hmc_fcn_req - Free VF HMC Function
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 */
static enum irdma_status_code irdma_vchnl_vf_send_put_hmc_fcn_req(struct irdma_sc_dev *dev,
								  struct irdma_virtchnl_req *vchnl_req)
{
	enum irdma_status_code ret_code = IRDMA_ERR_NOT_READY;
	struct irdma_virtchnl_op_buf *vchnl_msg = vchnl_req->vchnl_msg;

	if (!dev->vchnl_up)
		return ret_code;

	memset(vchnl_msg, 0, sizeof(*vchnl_msg));
	vchnl_msg->op_ctx = (uintptr_t)vchnl_req;
	vchnl_msg->buf_len = sizeof(*vchnl_msg);
	vchnl_msg->op_code = IRDMA_VCHNL_OP_PUT_HMC_FCN;
	vchnl_msg->op_ver = IRDMA_VCHNL_OP_PUT_HMC_FCN_V0;
	dev->vchnl_req = vchnl_req;
	ret_code = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)vchnl_msg,
					    vchnl_msg->buf_len);
	if (ret_code)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
	return ret_code;
}

/**
 * vchnl_vf_send_add_hmc_objs_req - Add HMC objects
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 * @rsrc_type: resource type of the object
 * @start_index: index of the first object in the range
 * @rsrc_count: number of objects in the range
 *
 * Builds an IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE message carrying an object
 * range descriptor as payload, marks @vchnl_req as the outstanding request
 * and sends it.
 *
 * Return: IRDMA_ERR_NOT_READY when the channel is down, otherwise the status
 * reported by the send callback.
 */
static enum irdma_status_code vchnl_vf_send_add_hmc_objs_req(struct irdma_sc_dev *dev,
							     struct irdma_virtchnl_req *vchnl_req,
							     enum irdma_hmc_rsrc_type rsrc_type,
							     u32 start_index,
							     u32 rsrc_count)
{
	struct irdma_virtchnl_hmc_obj_range *range;
	struct irdma_virtchnl_op_buf *msg;
	enum irdma_status_code status;

	if (!dev->vchnl_up)
		return IRDMA_ERR_NOT_READY;

	msg = vchnl_req->vchnl_msg;
	/* The range descriptor is the message payload */
	range = (struct irdma_virtchnl_hmc_obj_range *)msg->buf;

	memset(msg, 0, sizeof(*msg));
	memset(range, 0, sizeof(*range));

	msg->op_code = IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE;
	msg->op_ver = IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE_V0;
	msg->buf_len = sizeof(*msg) + sizeof(*range);
	msg->op_ctx = (uintptr_t)vchnl_req;

	range->obj_type = (u16)rsrc_type;
	range->start_index = start_index;
	range->obj_count = rsrc_count;
	irdma_dbg(dev,
		  "VIRT: Sending message: obj_type = %d, start_index = %d, obj_count = %d\n",
		  range->obj_type, range->start_index, range->obj_count);

	dev->vchnl_req = vchnl_req;
	status = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)msg, msg->buf_len);
	if (status)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  status);

	return status;
}

/**
 * vchnl_vf_send_del_hmc_objs_req - del HMC objects
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 * @rsrc_type: resource type to delete
 * @start_index: starting index for resource
 * @rsrc_count: number of resource type to delete
 *
 * Builds an IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE message carrying an object
 * range descriptor as payload, marks @vchnl_req as the outstanding request
 * and sends it.
 *
 * Return: IRDMA_ERR_NOT_READY when the channel is down, otherwise the status
 * reported by the send callback.
 */
static enum irdma_status_code vchnl_vf_send_del_hmc_objs_req(struct irdma_sc_dev *dev,
							     struct irdma_virtchnl_req *vchnl_req,
							     enum irdma_hmc_rsrc_type rsrc_type,
							     u32 start_index,
							     u32 rsrc_count)
{
	struct irdma_virtchnl_hmc_obj_range *range;
	struct irdma_virtchnl_op_buf *msg;
	enum irdma_status_code status;

	if (!dev->vchnl_up)
		return IRDMA_ERR_NOT_READY;

	msg = vchnl_req->vchnl_msg;
	/* The range descriptor is the message payload */
	range = (struct irdma_virtchnl_hmc_obj_range *)msg->buf;

	memset(msg, 0, sizeof(*msg));
	memset(range, 0, sizeof(*range));

	msg->op_code = IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE;
	msg->op_ver = IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE_V0;
	msg->buf_len = sizeof(*msg) + sizeof(*range);
	msg->op_ctx = (uintptr_t)vchnl_req;

	range->obj_type = (u16)rsrc_type;
	range->start_index = start_index;
	range->obj_count = rsrc_count;

	dev->vchnl_req = vchnl_req;
	status = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)msg, msg->buf_len);
	if (status)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  status);

	return status;
}

/**
 * irdma_vchnl_vf_manage_ws_node - manage ws node
 * @dev: RDMA device pointer
 * @vchnl_req: Virtual channel message request pointer
 * @add: Add or remove ws node
 * @user_pri: user priority of ws node
 *
 * Builds an IRDMA_VCHNL_OP_MANAGE_WS_NODE message carrying @add and
 * @user_pri as payload, marks @vchnl_req as the outstanding request and
 * sends it. This function only sends; it does not wait for the response.
 *
 * Return: IRDMA_ERR_NOT_READY when the channel is down, otherwise the status
 * reported by the send callback.
 */
enum irdma_status_code irdma_vchnl_vf_manage_ws_node(struct irdma_sc_dev *dev,
						     struct irdma_virtchnl_req *vchnl_req,
						     bool add,
						     u8 user_pri)
{
	struct irdma_virtchnl_manage_ws_node *node;
	struct irdma_virtchnl_op_buf *msg;
	enum irdma_status_code status;

	if (!dev->vchnl_up)
		return IRDMA_ERR_NOT_READY;

	msg = vchnl_req->vchnl_msg;
	/* The ws node descriptor is the message payload */
	node = (struct irdma_virtchnl_manage_ws_node *)msg->buf;

	memset(msg, 0, sizeof(*msg));
	memset(node, 0, sizeof(*node));

	msg->op_code = IRDMA_VCHNL_OP_MANAGE_WS_NODE;
	msg->op_ver = IRDMA_VCHNL_OP_MANAGE_WS_NODE_V0;
	msg->buf_len = sizeof(*msg) + sizeof(*node);
	msg->op_ctx = (uintptr_t)vchnl_req;

	node->add = add;
	node->user_pri = user_pri;
	irdma_dbg(dev,
		  "VIRT: Sending message: manage_ws_node add = %d, user_pri = %d\n",
		  node->add, node->user_pri);

	dev->vchnl_req = vchnl_req;
	status = dev->vchnl_if->vchnl_send(dev, 0, (u8 *)msg, msg->buf_len);
	if (status)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  status);

	return status;
}

/*
 * HMC resource types that pf_valid_hmc_rsrc_type() accepts in a VF request
 * when the hardware revision is IRDMA_GEN_1.
 */
static enum irdma_hmc_rsrc_type hmc_rsrc_types_gen1[] = {
	IRDMA_HMC_IW_QP,
	IRDMA_HMC_IW_CQ,
	IRDMA_HMC_IW_HTE,
	IRDMA_HMC_IW_ARP,
	IRDMA_HMC_IW_APBVT_ENTRY,
	IRDMA_HMC_IW_MR,
	IRDMA_HMC_IW_XF,
	IRDMA_HMC_IW_XFFL,
	IRDMA_HMC_IW_Q1,
	IRDMA_HMC_IW_Q1FL,
	IRDMA_HMC_IW_TIMER,
	IRDMA_HMC_IW_PBLE
};

/*
 * HMC resource types that pf_valid_hmc_rsrc_type() accepts in a VF request
 * when the hardware revision is IRDMA_GEN_2. This is the GEN_1 list plus
 * the FSIMC, FSIAV, RRF, RRFFL, HDR, MD, OOISC and OOISCFFL types.
 */
static enum irdma_hmc_rsrc_type hmc_rsrc_types_gen2[] = {
	IRDMA_HMC_IW_QP,
	IRDMA_HMC_IW_CQ,
	IRDMA_HMC_IW_HTE,
	IRDMA_HMC_IW_ARP,
	IRDMA_HMC_IW_APBVT_ENTRY,
	IRDMA_HMC_IW_MR,
	IRDMA_HMC_IW_XF,
	IRDMA_HMC_IW_XFFL,
	IRDMA_HMC_IW_Q1,
	IRDMA_HMC_IW_Q1FL,
	IRDMA_HMC_IW_TIMER,
	IRDMA_HMC_IW_FSIMC,
	IRDMA_HMC_IW_FSIAV,
	IRDMA_HMC_IW_PBLE,
	IRDMA_HMC_IW_RRF,
	IRDMA_HMC_IW_RRFFL,
	IRDMA_HMC_IW_HDR,
	IRDMA_HMC_IW_MD,
	IRDMA_HMC_IW_OOISC,
	IRDMA_HMC_IW_OOISCFFL,
};

/**
 * irdma_find_vf_dev - get vf struct pointer
 * @dev: shared device pointer
 * @vf_id: virtual function id
 *
 * Scans the device's VF table under vf_dev_lock for an entry whose vf_id
 * matches. The reference count of a matching entry is incremented before it
 * is returned.
 *
 * Return: the matching vf_dev, or NULL if none is found.
 */
struct irdma_vfdev *irdma_find_vf_dev(struct irdma_sc_dev *dev, u16 vf_id)
{
	struct irdma_vfdev *found = NULL;
	unsigned long flags;
	u16 idx;

	spin_lock_irqsave(&dev->vf_dev_lock, flags);
	for (idx = 0; idx < dev->num_vfs; idx++) {
		struct irdma_vfdev *entry = dev->vf_dev[idx];

		if (!entry || entry->vf_id != vf_id)
			continue;

		refcount_inc(&entry->refcnt);
		found = entry;
		break;
	}
	spin_unlock_irqrestore(&dev->vf_dev_lock, flags);

	return found;
}

/**
 * irdma_remove_vf_dev - remove vf_dev
 * @dev: shared device pointer
 * @vf_dev: vf dev to be removed
 *
 * Clears the VF table slot indexed by @vf_dev's iw_vf_idx while holding
 * vf_dev_lock. The vf_dev itself is neither freed nor dereferenced further.
 */
void irdma_remove_vf_dev(struct irdma_sc_dev *dev, struct irdma_vfdev *vf_dev)
{
	struct irdma_vfdev **slot;
	unsigned long flags;

	spin_lock_irqsave(&dev->vf_dev_lock, flags);
	slot = &dev->vf_dev[vf_dev->iw_vf_idx];
	*slot = NULL;
	spin_unlock_irqrestore(&dev->vf_dev_lock, flags);
}

/**
 * vchnl_pf_send_resp - Send a response message to a VF
 * @dev: RDMA device pointer
 * @vf_id: Virtual function ID associated with the message
 * @vchnl_msg: Virtual channel message buffer pointer
 * @param: parameter that is passed back to the VF
 * @param_len: length of parameter that's being passed in
 * @resp_code: response code sent back to VF
 *
 * Builds the response in a zeroed on-stack buffer: the request's op_ctx is
 * copied back, @resp_code is stored as the return code and @param_len bytes
 * of @param are appended as payload. A payload that is NULL or would not fit
 * in the buffer is rejected and nothing is sent.
 */
static void vchnl_pf_send_resp(struct irdma_sc_dev *dev, u16 vf_id,
			       struct irdma_virtchnl_op_buf *vchnl_msg, void *param,
			       u16 param_len, enum irdma_status_code resp_code)
{
	enum irdma_status_code ret_code;
	u8 resp_buf[IRDMA_VCHNL_MAX_VF_MSG_SIZE] = {};
	struct irdma_virtchnl_resp_buf *vchnl_msg_resp;
	size_t hdr_len;

	vchnl_msg_resp = (struct irdma_virtchnl_resp_buf *)resp_buf;

	/* Bytes in resp_buf that precede the payload area */
	hdr_len = (u8 *)vchnl_msg_resp->buf - resp_buf;

	/*
	 * Guard both the memcpy() into the stack buffer and the length handed
	 * to the send callback against an oversized payload.
	 */
	if (param_len &&
	    (!param || param_len > sizeof(resp_buf) - hdr_len ||
	     IRDMA_VCHNL_RESP_DEFAULT_SIZE + param_len > sizeof(resp_buf))) {
		irdma_dbg(dev, "VIRT: invalid response payload, len %u\n",
			  param_len);
		return;
	}

	vchnl_msg_resp->op_ctx = vchnl_msg->op_ctx;
	vchnl_msg_resp->buf_len = IRDMA_VCHNL_RESP_DEFAULT_SIZE + param_len;
	vchnl_msg_resp->op_ret_code = (s16)resp_code;
	if (param_len)
		memcpy(vchnl_msg_resp->buf, param, param_len);

	ret_code = dev->vchnl_if->vchnl_send(dev, vf_id, resp_buf, vchnl_msg_resp->buf_len);
	if (ret_code)
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
}

/**
 * pf_valid_hmc_rsrc_type - Check obj_type input validation
 * @hw_rev: hw version
 * @obj_type: type of hmc resource
 *
 * Looks @obj_type up in the resource type table that belongs to @hw_rev.
 *
 * Return: true if @obj_type is listed for @hw_rev; false if it is not, or if
 * @hw_rev is neither IRDMA_GEN_1 nor IRDMA_GEN_2.
 */
static bool pf_valid_hmc_rsrc_type(u8 hw_rev, u16 obj_type)
{
	enum irdma_hmc_rsrc_type *table;
	u8 count, idx;

	if (hw_rev == IRDMA_GEN_1) {
		table = hmc_rsrc_types_gen1;
		count = ARRAY_SIZE(hmc_rsrc_types_gen1);
	} else if (hw_rev == IRDMA_GEN_2) {
		table = hmc_rsrc_types_gen2;
		count = ARRAY_SIZE(hmc_rsrc_types_gen2);
	} else {
		return false;
	}

	for (idx = 0; idx < count; idx++) {
		if (table[idx] == obj_type)
			return true;
	}

	return false;
}

/**
 * irdma_pf_add_hmc_obj - Add HMC Object for VF
 * @vf_dev: pointer to the vf_dev
 * @hmc_obj: hmc_obj to be added
 */
static enum irdma_status_code irdma_pf_add_hmc_obj(struct irdma_vfdev *vf_dev,
						   struct irdma_virtchnl_hmc_obj_range *hmc_obj)
{
	struct irdma_sc_dev *dev = vf_dev->pf_dev;
	struct irdma_hmc_info *hmc_info = &vf_dev->hmc_info;
	struct irdma_hmc_create_obj_info info = {};
	enum irdma_status_code ret_code;

	if (!vf_dev->pf_hmc_initialized) {
		ret_code = irdma_pf_init_vfhmc(vf_dev->pf_dev, (u8)vf_dev->pmf_index);
		if (ret_code)
			return ret_code;
		vf_dev->pf_hmc_initialized = true;
	}

	if (!pf_valid_hmc_rsrc_type(dev->hw_attrs.uk_attrs.hw_rev, hmc_obj->obj_type)) {
		irdma_dbg(dev,
			  "VIRT: invalid hmc_rsrc type detected. vf_id %d obj_type 0x%x\n",
			  vf_dev->vf_id, hmc_obj->obj_type);
		return IRDMA_ERR_PARAM;
	}

	info.hmc_info = hmc_info;
	info.privileged = false;
	info.rsrc_type = (u32)hmc_obj->obj_type;
	info.entry_type = (info.rsrc_type == IRDMA_HMC_IW_PBLE) ?
			  IRDMA_SD_TYPE_PAGED : IRDMA_SD_TYPE_DIRECT;
	info.start_idx = hmc_obj->start_index;
	info.count = hmc_obj->obj_count;
	irdma_dbg(vf_dev->pf_dev,
		  "VIRT: IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE.  Add %u type %u objects\n",
		  info.count, info.rsrc_type);

	return irdma_sc_create_hmc_obj(vf_dev->pf_dev, &info);
}

/**
 * irdma_pf_del_hmc_obj - Delete HMC Object for VF
 * @vf_dev: pointer to the vf_dev
 * @hmc_obj: hmc_obj to be deleted
 */
static enum irdma_status_code irdma_pf_del_hmc_obj(struct irdma_vfdev *vf_dev,
						   struct irdma_virtchnl_hmc_obj_range *hmc_obj)
{
	struct irdma_sc_dev *dev = vf_dev->pf_dev;
	struct irdma_hmc_info *hmc_info = &vf_dev->hmc_info;
	struct irdma_hmc_del_obj_info info = {};

	if (!vf_dev->pf_hmc_initialized)
		return IRDMA_ERR_PARAM;

	if (!pf_valid_hmc_rsrc_type(dev->hw_attrs.uk_attrs.hw_rev, hmc_obj->obj_type)) {
		irdma_dbg(dev,
			  "VIRT: invalid hmc_rsrc type detected. vf_id %d obj_type 0x%x\n",
			  vf_dev->vf_id, hmc_obj->obj_type);
		return IRDMA_ERR_PARAM;
	}

	info.hmc_info = hmc_info;
	info.privileged = false;
	info.rsrc_type = (u32)hmc_obj->obj_type;
	info.start_idx = hmc_obj->start_index;
	info.count = hmc_obj->obj_count;
	irdma_dbg(vf_dev->pf_dev,
		  "VIRT: IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE. Delete %u type %u objects\n",
		  info.count, info.rsrc_type);

	return irdma_sc_del_hmc_obj(vf_dev->pf_dev, &info, false);
}

/**
 * irdma_pf_manage_ws_node - managing ws node for VF
 * @vf_dev: pointer to the VF Device
 * @ws_node: work scheduler node to be modified
 * @qs_handle: returned qs_handle provided by cqp
 *
 * Adds or removes the work scheduler node for the requested user priority on
 * the VF's VSI. @qs_handle is written only after a successful add.
 *
 * Return: IRDMA_ERR_PARAM if the user priority is out of range, the ws_add
 * status for an add request, 0 for a remove request.
 */
static enum irdma_status_code irdma_pf_manage_ws_node(struct irdma_vfdev *vf_dev,
						      struct irdma_virtchnl_manage_ws_node *ws_node,
						      u16 *qs_handle)
{
	struct irdma_sc_vsi *vsi = vf_dev->vf_vsi;
	enum irdma_status_code status;

	if (ws_node->user_pri >= IRDMA_MAX_USER_PRIORITY)
		return IRDMA_ERR_PARAM;

	irdma_dbg(vf_dev->pf_dev,
		  "VIRT: IRDMA_VCHNL_OP_MANAGE_WS_NODE. Add %d vf_id %d\n",
		  ws_node->add, vf_dev->vf_id);

	if (!ws_node->add) {
		vsi->dev->ws_remove(vsi, ws_node->user_pri);
		return 0;
	}

	status = vsi->dev->ws_add(vsi, ws_node->user_pri);
	if (status) {
		irdma_dbg(vf_dev->pf_dev,
			  "VIRT: irdma_ws_add failed ret_code = %x\n",
			  status);
		return status;
	}

	*qs_handle = vsi->qos[ws_node->user_pri].qs_handle;

	return status;
}

/**
 * irdma_set_hmc_fcn_info - Populate hmc_fcn_info struct
 * @vf_dev: pointer to VF dev structure
 * @hmc_fcn_info: pointer to HMC fcn info to be filled up
 *
 * Zeroes @hmc_fcn_info and stores the VF's id in it; every other field is
 * left at zero for the caller to adjust.
 */
static void irdma_set_hmc_fcn_info(struct irdma_vfdev *vf_dev,
				   struct irdma_hmc_fcn_info *hmc_fcn_info)
{
	/* Start from a clean slate so unset fields read as zero */
	memset(hmc_fcn_info, 0, sizeof(*hmc_fcn_info));
	hmc_fcn_info->vf_id = vf_dev->vf_id;
}

/**
 * irdma_get_next_vf_idx - return the next vf_idx available
 * @dev: pointer to RDMA dev structure
 *
 * Return: the index of the first empty slot in the device's VF table, or
 * IRDMA_VCHNL_INVALID_VF_IDX when all num_vfs slots are occupied.
 */
static u16 irdma_get_next_vf_idx(struct irdma_sc_dev *dev)
{
	u16 idx = 0;

	/* Skip over occupied slots */
	while (idx < dev->num_vfs && dev->vf_dev[idx])
		idx++;

	if (idx < dev->num_vfs)
		return idx;

	return IRDMA_VCHNL_INVALID_VF_IDX;
}

/**
 * irdma_put_vfdev - put vfdev and free memory
 * @dev: pointer to RDMA dev structure
 * @vf_dev: pointer to RDMA vf dev structure
 *
 * Drops one reference on @vf_dev. When the last reference goes away the SD
 * entry table and the vf_dev itself are freed and the device reference is
 * released through irdma_put_dev_ref().
 */
void irdma_put_vfdev(struct irdma_sc_dev *dev, struct irdma_vfdev *vf_dev)
{
	if (!refcount_dec_and_test(&vf_dev->refcnt))
		return;

	/* kfree() tolerates NULL, so an absent SD table needs no special case */
	kfree(vf_dev->hmc_info.sd_table.sd_entry);
	kfree(vf_dev);
	irdma_put_dev_ref(dev);
}

/**
 * irdma_pf_get_vf_hmc_fcn - Get hmc fcn from CQP for VF
 * @dev: pointer to RDMA dev structure
 * @vf_id: vf id of the hmc fcn requester
 *
 * Allocates a vf_dev (with its HMC object info array placed directly behind
 * it), installs it in the first free VF table slot, sets up the VSI context
 * and asks the CQP for an HMC function. On any failure the slot is cleared
 * and the allocation is freed.
 *
 * Return: the new vf_dev holding two references (one taken at creation, one
 * for the caller), or NULL on failure.
 */
static struct irdma_vfdev *irdma_pf_get_vf_hmc_fcn(struct irdma_sc_dev *dev, u16 vf_id)
{
	struct irdma_hmc_fcn_info hmc_fcn_info;
	struct irdma_vfdev *vf_dev;
	struct irdma_sc_vsi *vsi;
	u16 slot;

	slot = irdma_get_next_vf_idx(dev);
	if (slot == IRDMA_VCHNL_INVALID_VF_IDX)
		return NULL;

	vf_dev = kzalloc(sizeof(struct irdma_vfdev) +
			 sizeof(struct irdma_hmc_obj_info) * IRDMA_HMC_IW_MAX,
			 GFP_KERNEL);
	if (!vf_dev) {
		irdma_dbg(dev,
			  "VIRT: VF%u Unable to allocate a VF device structure.\n",
			  vf_id);
		return NULL;
	}

	vf_dev->pf_dev = dev;
	vf_dev->vf_id = vf_id;
	vf_dev->iw_vf_idx = slot;
	vf_dev->pf_hmc_initialized = false;
	/* The HMC object info array is the tail of the same allocation */
	vf_dev->hmc_info.hmc_obj = (struct irdma_hmc_obj_info *)(&vf_dev[1]);
	refcount_set(&vf_dev->refcnt, 1);

	irdma_dbg(dev, "VIRT: vf_dev %p, hmc_info %p, hmc_obj %p\n", vf_dev,
		  &vf_dev->hmc_info, vf_dev->hmc_info.hmc_obj);

	dev->vf_dev[slot] = vf_dev;
	vsi = irdma_update_vsi_ctx(dev, vf_dev, true);
	if (!vsi) {
		irdma_dbg(dev, "VIRT: VF%u failed updating vsi ctx .\n",
			  vf_id);
		dev->vf_dev[vf_dev->iw_vf_idx] = NULL;
		kfree(vf_dev);
		return NULL;
	}

	vf_dev->vf_vsi = vsi;
	vsi->vf_id = (u16)vf_dev->vf_id;
	vsi->vf_dev = vf_dev;

	irdma_set_hmc_fcn_info(vf_dev, &hmc_fcn_info);
	if (irdma_cqp_manage_hmc_fcn_cmd(dev, &hmc_fcn_info, &vf_dev->pmf_index)) {
		irdma_update_vsi_ctx(dev, vf_dev, false);
		dev->vf_dev[vf_dev->iw_vf_idx] = NULL;
		kfree(vf_dev);
		irdma_dbg(dev,
			  "VIRT: VF%u error CQP Get HMC Function operation.\n",
			  vf_id);
		return NULL;
	}

	irdma_dbg(dev, "VIRT: HMC Function allocated = 0x%08x\n",
		  vf_dev->pmf_index);

	irdma_add_dev_ref(dev);

	/* Caller references vf_dev */
	refcount_inc(&vf_dev->refcnt);

	return vf_dev;
}

/**
 * irdma_pf_put_vf_hmc_fcn - Put hmc fcn from CQP for VF
 * @dev: pointer to RDMA dev structure
 * @vf_dev: vf dev structure
 */
static void irdma_pf_put_vf_hmc_fcn(struct irdma_sc_dev *dev, struct irdma_vfdev *vf_dev)
{
	struct irdma_hmc_fcn_info hmc_fcn_info;

	irdma_set_hmc_fcn_info(vf_dev, &hmc_fcn_info);
	hmc_fcn_info.free_fcn = true;
	if (irdma_cqp_manage_hmc_fcn_cmd(dev, &hmc_fcn_info, &vf_dev->pmf_index)) {
		irdma_dbg(dev,
			  "VIRT: VF%u error CQP Free HMC Function operation.\n",
			  vf_dev->vf_id);
	}

	irdma_remove_vf_dev(dev, vf_dev);
	irdma_update_vsi_ctx(dev, vf_dev, false);
	irdma_put_vfdev(dev, vf_dev);
}

/**
 * irdma_recv_pf_worker - PF receive worker processes inbound vchnl request
 * @work: work element for the vchnl request
 *
 * Looks up the vf_dev for the sending VF, dispatches on the message opcode
 * and sends a response. No response is sent when the VF's reset_en flag is
 * set, or when an opcode that needs an existing vf_dev arrives from a VF
 * that has none. The reference taken by the lookup (or by HMC function
 * allocation) is dropped and the work item is freed before returning.
 */
static void irdma_recv_pf_worker(struct work_struct *work)
{
	struct irdma_virtchnl_work *vchnl_work = container_of(work, struct irdma_virtchnl_work, work);
	struct irdma_virtchnl_op_buf *vchnl_msg = (struct irdma_virtchnl_op_buf *)&vchnl_work->vf_msg_buf;
	u16 vf_id = vchnl_work->vf_id, qs_handle = 0, resp_len = 0;
	void *param = vchnl_msg->buf, *resp_param = NULL;
	enum irdma_status_code resp_code = 0;
	struct irdma_sc_dev *dev = vchnl_work->dev;
	struct irdma_vfdev *vf_dev = NULL;
	u8 vlan_parse_en;
	u32 vchnl_ver;

	irdma_dbg(dev, "VIRT: opcode %u\n", vchnl_msg->op_code);
	vf_dev = irdma_find_vf_dev(dev, vf_id);
	if (vf_dev && vf_dev->reset_en)
		goto free_work;

	switch (vchnl_msg->op_code) {
	case IRDMA_VCHNL_OP_GET_VER:
		vchnl_ver = IRDMA_VCHNL_CHNL_VER_V1;
		resp_code = vchnl_msg->op_ver == IRDMA_VCHNL_OP_GET_VER_V1 ?
						 IRDMA_SUCCESS : IRDMA_NOT_SUPPORTED;
		resp_param = &vchnl_ver;
		resp_len = sizeof(vchnl_ver);
		break;
	case IRDMA_VCHNL_OP_GET_HMC_FCN:
		/* First request from this VF: allocate its vf_dev */
		if (!vf_dev) {
			vf_dev = irdma_pf_get_vf_hmc_fcn(dev, vf_id);
			if (!vf_dev) {
				resp_code = IRDMA_ERR_NO_VF_AVAILABLE;
				break;
			}
		}

		resp_param = &vf_dev->pmf_index;
		resp_len = sizeof(vf_dev->pmf_index);
		break;
	case IRDMA_VCHNL_OP_PUT_HMC_FCN:
		if (!vf_dev)
			goto free_work;

		irdma_pf_put_vf_hmc_fcn(dev, vf_dev);
		break;
	case IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE:
		if (!vf_dev)
			goto free_work;

		resp_code = irdma_pf_add_hmc_obj(vf_dev, param);
		break;
	case IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE:
		if (!vf_dev)
			goto free_work;

		resp_code = irdma_pf_del_hmc_obj(vf_dev, param);
		break;
	case IRDMA_VCHNL_OP_MANAGE_WS_NODE:
		if (!vf_dev)
			goto free_work;

		resp_code = irdma_pf_manage_ws_node(vf_dev, param, &qs_handle);
		resp_param = &qs_handle;
		resp_len = sizeof(qs_handle);
		break;
	case IRDMA_VCHNL_OP_VLAN_PARSING:
		if (!vf_dev)
			goto free_work;

		vlan_parse_en = (!dev->double_vlan_en && !vf_dev->port_vlan_en) ||
				dev->double_vlan_en;
		irdma_dbg(dev,
			  "VIRT: port_vlan_en = 0x%x vlan_parse_en = 0x%x\n",
			  vf_dev->port_vlan_en, vlan_parse_en);

		resp_param = &vlan_parse_en;
		resp_len = sizeof(vlan_parse_en);
		break;
	default:
		irdma_dbg(dev, "VIRT: Invalid OpCode 0x%x\n",
			  vchnl_msg->op_code);
		resp_code = IRDMA_ERR_NOT_IMPL;
		break;
	}

	vchnl_pf_send_resp(dev, vf_id, vchnl_msg, resp_param, resp_len, resp_code);
free_work:
	if (vf_dev)
		irdma_put_vfdev(dev, vf_dev);

	/*
	 * Free the containing allocation, not the embedded work_struct, so
	 * this stays correct wherever 'work' sits inside the structure.
	 */
	kfree(vchnl_work);
}

/**
 * irdma_vchnl_pf_verify_msg - validate vf received vchannel message size
 * @vchnl_msg: inbound vf vchannel message
 * @len: length of the virtual channels message
 *
 * Every opcode requires at least IRDMA_VCHNL_RESP_DEFAULT_SIZE bytes and at
 * most IRDMA_VCHNL_MAX_VF_MSG_SIZE bytes. Opcodes that carry a payload
 * additionally require room for that payload structure.
 *
 * Return: true if @len is acceptable for the message's opcode.
 */
static bool irdma_vchnl_pf_verify_msg(struct irdma_virtchnl_op_buf *vchnl_msg,
				      u16 len)
{
	/*
	 * Apply the bounds shared by all opcodes before reading op_code, so
	 * a message shorter than the minimum is never dereferenced.
	 */
	if (len < IRDMA_VCHNL_RESP_DEFAULT_SIZE ||
	    len > IRDMA_VCHNL_MAX_VF_MSG_SIZE)
		return false;

	switch (vchnl_msg->op_code) {
	case IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE:
	case IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE:
		if (len < IRDMA_VCHNL_RESP_DEFAULT_SIZE +
			  sizeof(struct irdma_virtchnl_hmc_obj_range))
			return false;
		break;
	case IRDMA_VCHNL_OP_MANAGE_WS_NODE:
		if (len < IRDMA_VCHNL_RESP_DEFAULT_SIZE +
			  sizeof(struct irdma_virtchnl_manage_ws_node))
			return false;
		break;
	default:
		/* No payload requirement beyond the common minimum */
		break;
	}

	return true;
}
/**
 * irdma_vchnl_recv_pf - Receive PF virtual channel messages
 * @dev: RDMA device pointer
 * @vf_id: Virtual function ID associated with the message
 * @msg: Virtual channel message buffer pointer
 * @len: Length of the virtual channels message
 *
 * Validates the message, copies it into a newly allocated work item and
 * queues that item on the device's virtual channel workqueue, where
 * irdma_recv_pf_worker() handles it.
 *
 * Return: IRDMA_ERR_PARAM for a NULL or badly sized message,
 * IRDMA_ERR_NOT_READY if the channel is down, IRDMA_ERR_NO_MEMORY if the
 * work item cannot be allocated, 0 once the work is queued.
 */
enum irdma_status_code irdma_vchnl_recv_pf(struct irdma_sc_dev *dev, u16 vf_id,
					   u8 *msg, u16 len)
{
	struct irdma_virtchnl_work *vchnl_work;

	irdma_dbg(dev, "VIRT: VF%u: msg %p len %u chnl up %u", vf_id, msg,
		  len, dev->vchnl_up);

	if (!msg)
		return IRDMA_ERR_PARAM;

	if (!irdma_vchnl_pf_verify_msg((struct irdma_virtchnl_op_buf *)msg, len))
		return IRDMA_ERR_PARAM;

	if (!dev->vchnl_up)
		return IRDMA_ERR_NOT_READY;

	vchnl_work = kzalloc(sizeof(struct irdma_virtchnl_work), GFP_KERNEL);
	if (!vchnl_work)
		return IRDMA_ERR_NO_MEMORY;

	/* The worker runs later, so it needs its own copy of the message */
	memcpy(&vchnl_work->vf_msg_buf, msg, len);
	vchnl_work->len = len;
	vchnl_work->vf_id = vf_id;
	vchnl_work->dev = dev;

	INIT_WORK(&vchnl_work->work, irdma_recv_pf_worker);
	queue_work(dev->vchnl_wq, &vchnl_work->work);

	return 0;
}

/**
 * irdma_vchnl_vf_verify_resp - Verify requested response size
 * @vchnl_req: vchnl message requested
 * @resp_len: response length sent from vchnl peer
 *
 * For every known opcode the response payload must be exactly as long as
 * the request's parm_len.
 *
 * Return: 0 if the length matches; IRDMA_ERR_VF_MSG_ERROR if it does not, or
 * if the request's opcode is not one of the known opcodes.
 */
static enum irdma_status_code irdma_vchnl_vf_verify_resp(struct irdma_virtchnl_req *vchnl_req,
							 u16 resp_len)
{
	switch (vchnl_req->vchnl_msg->op_code) {
	case IRDMA_VCHNL_OP_GET_VER:
	case IRDMA_VCHNL_OP_GET_HMC_FCN:
	case IRDMA_VCHNL_OP_PUT_HMC_FCN:
	case IRDMA_VCHNL_OP_ADD_HMC_OBJ_RANGE:
	case IRDMA_VCHNL_OP_DEL_HMC_OBJ_RANGE:
	case IRDMA_VCHNL_OP_MANAGE_STATS_INST:
	case IRDMA_VCHNL_OP_MCG:
	case IRDMA_VCHNL_OP_UP_MAP:
	case IRDMA_VCHNL_OP_MANAGE_WS_NODE:
	case IRDMA_VCHNL_OP_GET_STATS:
	case IRDMA_VCHNL_OP_VLAN_PARSING:
		break;
	default:
		return IRDMA_ERR_VF_MSG_ERROR;
	}

	if (resp_len == vchnl_req->parm_len)
		return 0;

	return IRDMA_ERR_VF_MSG_ERROR;
}

/**
 * irdma_vchnl_recv_vf - Receive VF virtual channel messages
 * @dev: RDMA device pointer
 * @vf_id: Virtual function ID associated with the message
 * @msg: Virtual channel message buffer pointer
 * @len: Length of the virtual channels message
 *
 * Matches the response against the device's outstanding request using the
 * op_ctx value carried in the message, validates the payload length, stores
 * the peer's return code in the request and, on success, copies the payload
 * into the request's parm buffer.
 *
 * Return: IRDMA_ERR_PARAM for a NULL message, IRDMA_ERR_BUF_TOO_SHORT if the
 * message is shorter than the response header, IRDMA_ERR_VF_MSG_ERROR if the
 * context does not match the outstanding request or the payload length is
 * wrong, 0 otherwise.
 */
enum irdma_status_code irdma_vchnl_recv_vf(struct irdma_sc_dev *dev, u16 vf_id,
					   u8 *msg, u16 len)
{
	struct irdma_virtchnl_resp_buf *vchnl_msg_resp = (struct irdma_virtchnl_resp_buf *)msg;
	struct irdma_virtchnl_req *vchnl_req;
	u16 resp_len;

	if (!msg)
		return IRDMA_ERR_PARAM;

	if (len < sizeof(*vchnl_msg_resp))
		return IRDMA_ERR_BUF_TOO_SHORT;

	vchnl_req = (struct irdma_virtchnl_req *)(uintptr_t)
		    vchnl_msg_resp->op_ctx;

	/*
	 * op_ctx is taken from the received message. It may only be used as
	 * a pointer once it is known to equal the outstanding request; a
	 * mismatching (or NULL) value must never be dereferenced.
	 */
	if (!vchnl_req || (uintptr_t)vchnl_req != (uintptr_t)dev->vchnl_req) {
		irdma_dbg(dev,
			  "VIRT: error vchnl context value does not match\n");
		return IRDMA_ERR_VF_MSG_ERROR;
	}

	/* A truncated payload is reported with its real length and then rejected */
	if (len < sizeof(*vchnl_msg_resp) + vchnl_req->parm_len)
		resp_len = len - sizeof(*vchnl_msg_resp);
	else
		resp_len = vchnl_req->parm_len;

	if (irdma_vchnl_vf_verify_resp(vchnl_req, resp_len) != 0) {
		vchnl_req->ret_code = IRDMA_ERR_VF_MSG_ERROR;
		return IRDMA_ERR_VF_MSG_ERROR;
	}

	vchnl_req->ret_code = (enum irdma_status_code)vchnl_msg_resp->op_ret_code;

	if (!vchnl_req->ret_code && vchnl_req->parm_len && vchnl_req->parm && resp_len) {
		memcpy(vchnl_req->parm, vchnl_msg_resp->buf, resp_len);
		vchnl_req->resp_len = resp_len;
		irdma_dbg(dev, "VIRT: Got response, data size %u\n", resp_len);
	}

	return 0;
}

/**
 * irdma_vchnl_vf_get_ver - Request Channel version
 * @dev: RDMA device pointer
 * @vchnl_ver: Virtual channel message version pointer
 *
 * Sends a version request and waits for the response, which the response
 * handler copies into @vchnl_ver. The reported version is examined only
 * when the peer returned success, because the handler does not write the
 * payload otherwise.
 *
 * Return: IRDMA_ERR_TIMEOUT if the channel is not clear to send, a send or
 * wait error if one occurs, the peer's return code if it is non-zero,
 * IRDMA_NOT_SUPPORTED if the reported version is newer than
 * IRDMA_VCHNL_CHNL_VER_V1, 0 otherwise.
 */
enum irdma_status_code irdma_vchnl_vf_get_ver(struct irdma_sc_dev *dev, u32 *vchnl_ver)
{
	struct irdma_virtchnl_req vchnl_req = {};
	enum irdma_status_code ret_code;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	vchnl_req.dev = dev;
	vchnl_req.parm = vchnl_ver;
	vchnl_req.parm_len = sizeof(*vchnl_ver);
	vchnl_req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;

	ret_code = vchnl_vf_send_get_ver_req(dev, &vchnl_req);
	if (ret_code) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", ret_code);
		return ret_code;
	}

	ret_code = irdma_vf_wait_vchnl_resp(dev);
	if (ret_code)
		return ret_code;

	/* The peer's verdict comes first: on error no version was delivered */
	if (vchnl_req.ret_code)
		return vchnl_req.ret_code;

	/* Compare the whole 32-bit value rather than a single byte of it */
	if (*vchnl_ver > IRDMA_VCHNL_CHNL_VER_V1) {
		irdma_dbg(dev, "VIRT: %s unsupported vchnl version 0x%0x\n",
			  __func__, *vchnl_ver);
		return IRDMA_NOT_SUPPORTED;
	}

	return vchnl_req.ret_code;
}

/**
 * irdma_vchnl_vf_get_hmc_fcn - Request VF HMC Function
 * @dev: RDMA device pointer
 */
enum irdma_status_code irdma_vchnl_vf_get_hmc_fcn(struct irdma_sc_dev *dev)
{
	struct irdma_virtchnl_req vchnl_req = {};
	enum irdma_status_code ret_code;
	u16 hmc_fcn; /* Deprecated */

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	vchnl_req.dev = dev;
	vchnl_req.parm = &hmc_fcn;
	vchnl_req.parm_len = sizeof(hmc_fcn);
	vchnl_req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;

	ret_code = vchnl_vf_send_get_hmc_fcn_req(dev, &vchnl_req);
	if (ret_code) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", ret_code);
		return ret_code;
	}
	ret_code = irdma_vf_wait_vchnl_resp(dev);
	if (ret_code)
		return ret_code;
	else
		return vchnl_req.ret_code;
}

/**
 * irdma_vchnl_vf_put_hmc_fcn - Free VF HMC Function
 * @dev: RDMA device pointer
 */
enum irdma_status_code irdma_vchnl_vf_put_hmc_fcn(struct irdma_sc_dev *dev)
{
	struct irdma_virtchnl_req vchnl_req = {};
	enum irdma_status_code ret_code;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;
	vchnl_req.dev = dev;
	vchnl_req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;

	ret_code = irdma_vchnl_vf_send_put_hmc_fcn_req(dev, &vchnl_req);
	if (ret_code) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", ret_code);
		return ret_code;
	}
	ret_code = irdma_vf_wait_vchnl_resp(dev);
	if (ret_code)
		return ret_code;
	else
		return vchnl_req.ret_code;
}

/**
 * irdma_vchnl_vf_add_hmc_objs - Add HMC Object
 * @dev: RDMA device pointer
 * @rsrc_type: HMC Resource type
 * @start_index: Starting index of the objects to be added
 * @rsrc_count: Number of resources to be added
 *
 * Sends an add-object-range request and waits for the response. No response
 * payload is expected.
 *
 * Return: IRDMA_ERR_TIMEOUT if the channel is not clear to send, a send or
 * wait error if one occurs, otherwise the return code stored in the request
 * by the response handler.
 */
enum irdma_status_code irdma_vchnl_vf_add_hmc_objs(struct irdma_sc_dev *dev,
						   enum irdma_hmc_rsrc_type rsrc_type,
						   u32 start_index, u32 rsrc_count)
{
	struct irdma_virtchnl_req req = {};
	enum irdma_status_code status;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;
	req.dev = dev;

	status = vchnl_vf_send_add_hmc_objs_req(dev, &req, rsrc_type,
						start_index, rsrc_count);
	if (status) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", status);
		return status;
	}

	status = irdma_vf_wait_vchnl_resp(dev);
	if (status)
		return status;

	return req.ret_code;
}

/**
 * irdma_vchnl_vf_del_hmc_obj - del HMC obj
 * @dev: RDMA device pointer
 * @rsrc_type: HMC Resource type
 * @start_index: Starting index of the object to delete
 * @rsrc_count: Number of resources to be delete
 *
 * Sends a delete-object-range request and waits for the response. No
 * response payload is expected.
 *
 * Return: IRDMA_ERR_TIMEOUT if the channel is not clear to send, a send or
 * wait error if one occurs, otherwise the return code stored in the request
 * by the response handler.
 */
enum irdma_status_code irdma_vchnl_vf_del_hmc_obj(struct irdma_sc_dev *dev,
						  enum irdma_hmc_rsrc_type rsrc_type,
						  u32 start_index, u32 rsrc_count)
{
	struct irdma_virtchnl_req req = {};
	enum irdma_status_code status;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;
	req.dev = dev;

	status = vchnl_vf_send_del_hmc_objs_req(dev, &req, rsrc_type,
						start_index, rsrc_count);
	if (status) {
		irdma_dbg(dev, "VIRT: Send message failed 0x%0x\n", status);
		return status;
	}

	status = irdma_vf_wait_vchnl_resp(dev);
	if (status)
		return status;

	return req.ret_code;
}

/**
 * irdma_vchnl_vf_get_vlan_parsing_cfg - Find if vlan should be processed
 * @dev: Dev pointer
 * @vlan_parse_en: vlan parsing enabled
 *
 * Sends an IRDMA_VCHNL_OP_VLAN_PARSING request and waits for the response.
 * @vlan_parse_en is registered as the response buffer, so on success the
 * response handler has already stored the peer's answer in it.
 *
 * Return: IRDMA_ERR_TIMEOUT if the channel is not clear to send,
 * IRDMA_ERR_NOT_READY if the channel is down, a send or wait error if one
 * occurs, otherwise the return code stored in the request by the response
 * handler.
 */
enum irdma_status_code irdma_vchnl_vf_get_vlan_parsing_cfg(struct irdma_sc_dev *dev,
							   u8 *vlan_parse_en)
{
	enum irdma_status_code ret_code;
	struct irdma_virtchnl_req vchnl_req = {};
	struct irdma_virtchnl_op_buf *vchnl_msg;

	if (!irdma_vf_clear_to_send(dev))
		return IRDMA_ERR_TIMEOUT;

	/* Same channel-state gate the other request senders apply */
	if (!dev->vchnl_up)
		return IRDMA_ERR_NOT_READY;

	vchnl_req.dev = dev;
	vchnl_req.parm = vlan_parse_en;
	vchnl_req.parm_len = sizeof(*vlan_parse_en);
	vchnl_req.vchnl_msg = (struct irdma_virtchnl_op_buf *)&dev->vf_msg_buf;
	vchnl_msg = vchnl_req.vchnl_msg;

	memset(vchnl_msg, 0, sizeof(*vchnl_msg));
	/*
	 * The caller's current value is sent as the request payload; the PF
	 * handler in this file does not read it.
	 */
	memcpy(vchnl_msg->buf, vlan_parse_en, sizeof(*vlan_parse_en));
	vchnl_msg->op_ctx = (uintptr_t)&vchnl_req;
	vchnl_msg->buf_len = sizeof(*vchnl_msg) + sizeof(*vlan_parse_en);
	vchnl_msg->op_code = IRDMA_VCHNL_OP_VLAN_PARSING;
	vchnl_msg->op_ver = IRDMA_VCHNL_OP_VLAN_PARSING_V0;
	dev->vchnl_req = &vchnl_req;
	ret_code = dev->vchnl_if->vchnl_send(dev,
					     0,
					     (u8 *)vchnl_msg,
					     vchnl_msg->buf_len);
	if (ret_code) {
		irdma_dbg(dev, "VIRT: virt channel send failed 0x%x\n",
			  ret_code);
		return ret_code;
	}

	ret_code = irdma_vf_wait_vchnl_resp(dev);
	if (ret_code)
		return ret_code;

	/* vchnl_req.parm is @vlan_parse_en itself, so no copy-back is needed */
	return vchnl_req.ret_code;
}

