diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 62 |
1 files changed, 44 insertions, 18 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 9fbc70200cd0..211f57164cb6 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -52,6 +52,7 @@ struct tls_decrypt_arg { struct_group(inargs, bool zc; bool async; + bool async_done; u8 tail; ); @@ -274,22 +275,30 @@ static int tls_do_decryption(struct sock *sk, DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1); atomic_inc(&ctx->decrypt_pending); } else { + DECLARE_CRYPTO_WAIT(wait); + aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - crypto_req_done, &ctx->async_wait); + crypto_req_done, &wait); + ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS || ret == -EBUSY) + ret = crypto_wait_req(ret, &wait); + return ret; } ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS) + return 0; + if (ret == -EBUSY) { ret = tls_decrypt_async_wait(ctx); - ret = ret ?: -EINPROGRESS; + darg->async_done = true; + /* all completions have run, we're not doing async anymore */ + darg->async = false; + return ret; } - if (ret == -EINPROGRESS) { - if (darg->async) - return 0; - ret = crypto_wait_req(ret, &ctx->async_wait); - } + atomic_dec(&ctx->decrypt_pending); darg->async = false; return ret; @@ -1588,8 +1597,11 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, /* Prepare and submit AEAD request */ err = tls_do_decryption(sk, sgin, sgout, dctx->iv, data_len + prot->tail_size, aead_req, darg); - if (err) + if (err) { + if (darg->async_done) + goto exit_free_skb; goto exit_free_pages; + } darg->skb = clear_skb ?: tls_strp_msg(ctx); clear_skb = NULL; @@ -1601,6 +1613,9 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, return err; } + if (unlikely(darg->async_done)) + return 0; + if (prot->tail_size) darg->tail = dctx->tail; @@ -1772,7 +1787,8 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, u8 *control, size_t skip, size_t len, - bool is_peek) + bool is_peek, + bool *more) { struct sk_buff *skb = skb_peek(&ctx->rx_list); struct tls_msg *tlm; @@ -1785,7 +1801,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, err = tls_record_content_type(msg, tlm, control); if (err <= 0) - goto out; + goto more; if (skip < rxm->full_len) break; @@ -1803,12 +1819,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, err = tls_record_content_type(msg, tlm, control); if (err <= 0) - goto out; + goto more; err = skb_copy_datagram_msg(skb, rxm->offset + skip, msg, chunk); if (err < 0) - goto out; + goto more; len = len - chunk; copied = copied + chunk; @@ -1844,6 +1860,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, out: return copied ? : err; +more: + if (more) + *more = true; + goto out; } static bool @@ -1943,10 +1963,12 @@ int tls_sw_recvmsg(struct sock *sk, struct strp_msg *rxm; struct tls_msg *tlm; ssize_t copied = 0; + ssize_t peeked = 0; bool async = false; int target, err; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_peek = flags & MSG_PEEK; + bool rx_more = false; bool released = true; bool bpf_strp_enabled; bool zc_capable; @@ -1966,12 +1988,12 @@ int tls_sw_recvmsg(struct sock *sk, goto end; /* Process pending decrypted records. It must be non-zero-copy */ - err = process_rx_list(ctx, msg, &control, 0, len, is_peek); + err = process_rx_list(ctx, msg, &control, 0, len, is_peek, &rx_more); if (err < 0) goto end; copied = err; - if (len <= copied) + if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more) goto end; target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); @@ -2064,6 +2086,8 @@ put_on_rx_list: decrypted += chunk; len -= chunk; __skb_queue_tail(&ctx->rx_list, skb); + if (unlikely(control != TLS_RECORD_TYPE_DATA)) + break; continue; } @@ -2087,8 +2111,10 @@ put_on_rx_list: if (err < 0) goto put_on_rx_list_err; - if (is_peek) + if (is_peek) { + peeked += chunk; goto put_on_rx_list; + } if (partially_consumed) { rxm->offset += chunk; @@ -2127,11 +2153,11 @@ recv_end: /* Drain records from the rx_list & copy if required */ if (is_peek || is_kvec) - err = process_rx_list(ctx, msg, &control, copied, - decrypted, is_peek); + err = process_rx_list(ctx, msg, &control, copied + peeked, + decrypted - peeked, is_peek, NULL); else err = process_rx_list(ctx, msg, &control, 0, - async_copy_bytes, is_peek); + async_copy_bytes, is_peek, NULL); } copied += decrypted; |