diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 49 |
1 files changed, 39 insertions, 10 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 6205ad1a84c7..6a9875456f84 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -47,9 +47,13 @@ #include "tls.h" struct tls_decrypt_arg { + struct_group(inargs, bool zc; bool async; u8 tail; + ); + + struct sk_buff *skb; }; struct tls_decrypt_ctx { @@ -1412,6 +1416,7 @@ out: * ------------------------------------------------------------------- * zc | Zero-copy decrypt allowed | Zero-copy performed * async | Async decrypt allowed | Async crypto used / in progress + * skb | * | Output skb */ /* This function decrypts the input skb into either out_iov or in out_sg @@ -1551,12 +1556,17 @@ fallback_to_reg_recv: /* Prepare and submit AEAD request */ err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv, data_len + prot->tail_size, aead_req, darg); + if (err) + goto exit_free_pages; + + darg->skb = tls_strp_msg(ctx); if (darg->async) return 0; if (prot->tail_size) darg->tail = dctx->tail; +exit_free_pages: /* Release the pages in case iov was mapped to pages */ for (; pages > 0; pages--) put_page(sg_page(&sgout[pages])); @@ -1569,6 +1579,7 @@ static int tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, struct tls_decrypt_arg *darg) { + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); int err; if (tls_ctx->rx_conf != TLS_HW) @@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, darg->zc = false; darg->async = false; + darg->skb = tls_strp_msg(ctx); + ctx->recv_pkt = NULL; return 1; } @@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); return err; } - if (darg->async) + if (darg->async) { + if (darg->skb == ctx->recv_pkt) + ctx->recv_pkt = NULL; goto decrypt_next; + } /* If opportunistic TLS 1.3 ZC failed retry without ZC */ if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && darg->tail != TLS_RECORD_TYPE_DATA)) { @@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, return tls_rx_one_record(sk, dest, darg); } + if (darg->skb == ctx->recv_pkt) + ctx->recv_pkt = NULL; + decrypt_done: - pad = tls_padding_length(prot, ctx->recv_pkt, darg); - if (pad < 0) + pad = tls_padding_length(prot, darg->skb, darg); + if (pad < 0) { + consume_skb(darg->skb); return pad; + } - rxm = strp_msg(ctx->recv_pkt); + rxm = strp_msg(darg->skb); rxm->full_len -= pad; rxm->offset += prot->prepend_size; rxm->full_len -= prot->overhead_size; @@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) { + consume_skb(ctx->recv_pkt); ctx->recv_pkt = NULL; __strp_unpause(&ctx->strp); } @@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk, ctx->zc_capable; decrypted = 0; while (len && (decrypted + copied < target || ctx->recv_pkt)) { - struct tls_decrypt_arg darg = {}; + struct tls_decrypt_arg darg; int to_decrypt, chunk; err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo); @@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk, goto recv_end; } - skb = ctx->recv_pkt; - rxm = strp_msg(skb); - tlm = tls_msg(skb); + memset(&darg.inargs, 0, sizeof(darg.inargs)); + + rxm = strp_msg(ctx->recv_pkt); + tlm = tls_msg(ctx->recv_pkt); to_decrypt = rxm->full_len - prot->overhead_size; @@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk, goto recv_end; } + skb = darg.skb; + rxm = strp_msg(skb); + tlm = tls_msg(skb); + async |= darg.async; /* If the type of records being processed is not known yet, @@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, if (!skb_queue_empty(&ctx->rx_list)) { skb = __skb_dequeue(&ctx->rx_list); } else { - struct tls_decrypt_arg darg = {}; + struct tls_decrypt_arg darg; err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo); if (err <= 0) goto splice_read_end; + memset(&darg.inargs, 0, sizeof(darg.inargs)); + err = tls_rx_one_record(sk, NULL, &darg); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto splice_read_end; } - skb = ctx->recv_pkt; tls_rx_rec_done(ctx); + skb = darg.skb; } rxm = strp_msg(skb); |