aboutsummaryrefslogtreecommitdiff
path: root/drivers/virt
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/virt')
-rw-r--r--drivers/virt/coco/sev-guest/sev-guest.c45
1 files changed, 25 insertions, 20 deletions
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 5cdc972e5ef4..c47e54b2a865 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -57,6 +57,11 @@ struct snp_guest_dev {
struct snp_secrets_page_layout *layout;
struct snp_req_data input;
+ union {
+ struct snp_report_req report;
+ struct snp_derived_key_req derived_key;
+ struct snp_ext_report_req ext_report;
+ } req;
u32 *os_area_msg_seqno;
u8 *vmpck;
};
@@ -475,8 +480,8 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
{
struct snp_guest_crypto *crypto = snp_dev->crypto;
+ struct snp_report_req *req = &snp_dev->req.report;
struct snp_report_resp *resp;
- struct snp_report_req req;
int rc, resp_len;
lockdep_assert_held(&snp_cmd_mutex);
@@ -484,7 +489,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
if (!arg->req_data || !arg->resp_data)
return -EINVAL;
- if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
+ if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
return -EFAULT;
/*
@@ -498,7 +503,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
return -ENOMEM;
rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
- SNP_MSG_REPORT_REQ, &req, sizeof(req), resp->data,
+ SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
resp_len);
if (rc)
goto e_free;
@@ -513,9 +518,9 @@ e_free:
static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
{
+ struct snp_derived_key_req *req = &snp_dev->req.derived_key;
struct snp_guest_crypto *crypto = snp_dev->crypto;
struct snp_derived_key_resp resp = {0};
- struct snp_derived_key_req req;
int rc, resp_len;
/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
u8 buf[64 + 16];
@@ -534,11 +539,11 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
if (sizeof(buf) < resp_len)
return -ENOMEM;
- if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
+ if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
return -EFAULT;
rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
- SNP_MSG_KEY_REQ, &req, sizeof(req), buf, resp_len);
+ SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
if (rc)
return rc;
@@ -554,8 +559,8 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
{
+ struct snp_ext_report_req *req = &snp_dev->req.ext_report;
struct snp_guest_crypto *crypto = snp_dev->crypto;
- struct snp_ext_report_req req;
struct snp_report_resp *resp;
int ret, npages = 0, resp_len;
@@ -564,18 +569,18 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
if (!arg->req_data || !arg->resp_data)
return -EINVAL;
- if (copy_from_user(&req, (void __user *)arg->req_data, sizeof(req)))
+ if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
return -EFAULT;
/* userspace does not want certificate data */
- if (!req.certs_len || !req.certs_address)
+ if (!req->certs_len || !req->certs_address)
goto cmd;
- if (req.certs_len > SEV_FW_BLOB_MAX_SIZE ||
- !IS_ALIGNED(req.certs_len, PAGE_SIZE))
+ if (req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
+ !IS_ALIGNED(req->certs_len, PAGE_SIZE))
return -EINVAL;
- if (!access_ok((const void __user *)req.certs_address, req.certs_len))
+ if (!access_ok((const void __user *)req->certs_address, req->certs_len))
return -EFAULT;
/*
@@ -584,8 +589,8 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
* the host. If host does not supply any certs in it, then copy
* zeros to indicate that certificate data was not provided.
*/
- memset(snp_dev->certs_data, 0, req.certs_len);
- npages = req.certs_len >> PAGE_SHIFT;
+ memset(snp_dev->certs_data, 0, req->certs_len);
+ npages = req->certs_len >> PAGE_SHIFT;
cmd:
/*
* The intermediate response buffer is used while decrypting the
@@ -599,14 +604,14 @@ cmd:
snp_dev->input.data_npages = npages;
ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
- SNP_MSG_REPORT_REQ, &req.data,
- sizeof(req.data), resp->data, resp_len);
+ SNP_MSG_REPORT_REQ, &req->data,
+ sizeof(req->data), resp->data, resp_len);
/* If certs length is invalid then copy the returned length */
if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
- req.certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
+ req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
- if (copy_to_user((void __user *)arg->req_data, &req, sizeof(req)))
+ if (copy_to_user((void __user *)arg->req_data, req, sizeof(*req)))
ret = -EFAULT;
}
@@ -614,8 +619,8 @@ cmd:
goto e_free;
if (npages &&
- copy_to_user((void __user *)req.certs_address, snp_dev->certs_data,
- req.certs_len)) {
+ copy_to_user((void __user *)req->certs_address, snp_dev->certs_data,
+ req->certs_len)) {
ret = -EFAULT;
goto e_free;
}