diff options
author | Jakub Kicinski | 2022-04-08 11:31:25 -0700 |
---|---|---|
committer | David S. Miller | 2022-04-10 17:32:11 +0100 |
commit | 9bdf75ccffa690237cd0b472cd598cf6d22873dc (patch) | |
tree | a0bc324249adf544edabc9e73a67061ddd07b9b2 /net/tls | |
parent | d4bd88e67666c73cfa9d75c282e708890d4f10a7 (diff) |
tls: rx: don't report text length from the bowels of decrypt
We plumb pointer to chunk all the way to the decryption method.
It's set to the length of the text when decrypt_skb_update()
returns.
I think the code is written this way because original TLS
implementation passed &chunk to zerocopy_from_iter() and this
was carried forward as the code gotten more complex, without
any refactoring.
The fix for peek() introduced a new variable - to_decrypt
which for all practical purposes is what chunk is going to
get set to. Spare ourselves the pointer passing, use to_decrypt.
Use this opportunity to clean things up a little further.
Note that chunk / to_decrypt was mostly needed for the async
path, since the sync path would access rxm->full_len (decryption
transforms full_len from record size to text size). Use the
right source of truth more explicitly.
We have three cases:
- async - it's TLS 1.2 only, so chunk == to_decrypt, but we
need the min() because to_decrypt is a whole record
and we don't want to underflow len. Note that we can't
handle partial record by falling back to sync as it
would introduce reordering against records in flight.
- zc - again, TLS 1.2 only for now, so chunk == to_decrypt,
we don't do zc if len < to_decrypt, no need to check again.
- normal - it already handles chunk > len, we can factor out the
assignment to rxm->full_len and share it with zc.
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_sw.c | 33 |
1 files changed, 14 insertions, 19 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 86f77f8b825e..c321c5f85fbe 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1412,7 +1412,7 @@ out: static int decrypt_internal(struct sock *sk, struct sk_buff *skb, struct iov_iter *out_iov, struct scatterlist *out_sg, - int *chunk, bool *zc, bool async) + bool *zc, bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -1526,7 +1526,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, (n_sgout - 1)); if (err < 0) goto fallback_to_reg_recv; - *chunk = data_len; } else if (out_sg) { memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); } else { @@ -1536,7 +1535,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, fallback_to_reg_recv: sgout = sgin; pages = 0; - *chunk = data_len; *zc = false; } @@ -1555,8 +1553,7 @@ fallback_to_reg_recv: } static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, - struct iov_iter *dest, int *chunk, bool *zc, - bool async) + struct iov_iter *dest, bool *zc, bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_prot_info *prot = &tls_ctx->prot_info; @@ -1580,7 +1577,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, } } - err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); + err = decrypt_internal(sk, skb, dest, NULL, zc, async); if (err < 0) { if (err == -EINPROGRESS) tls_advance_record_sn(sk, prot, &tls_ctx->rx); @@ -1607,9 +1604,8 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb, struct scatterlist *sgout) { bool zc = true; - int chunk; - return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false); + return decrypt_internal(sk, skb, NULL, sgout, &zc, false); } static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, @@ -1799,9 +1795,8 @@ int tls_sw_recvmsg(struct sock *sk, num_async = 0; while (len && (decrypted + copied < target || ctx->recv_pkt)) { bool retain_skb = false; + int to_decrypt, chunk; bool zc = false; - int to_decrypt; - int chunk = 0; bool async_capable; bool async = false; @@ -1838,7 +1833,7 @@ int tls_sw_recvmsg(struct sock *sk, async_capable = false; err = decrypt_skb_update(sk, skb, &msg->msg_iter, - &chunk, &zc, async_capable); + &zc, async_capable); if (err < 0 && err != -EINPROGRESS) { tls_err_abort(sk, -EBADMSG); goto recv_end; @@ -1876,8 +1871,13 @@ int tls_sw_recvmsg(struct sock *sk, } } - if (async) + if (async) { + /* TLS 1.2-only, to_decrypt must be text length */ + chunk = min_t(int, to_decrypt, len); goto pick_next_record; + } + /* TLS 1.3 may have updated the length by more than overhead */ + chunk = rxm->full_len; if (!zc) { if (bpf_strp_enabled) { @@ -1893,11 +1893,9 @@ int tls_sw_recvmsg(struct sock *sk, } } - if (rxm->full_len > len) { + if (chunk > len) { retain_skb = true; chunk = len; - } else { - chunk = rxm->full_len; } err = skb_copy_datagram_msg(skb, rxm->offset, @@ -1912,9 +1910,6 @@ int tls_sw_recvmsg(struct sock *sk, } pick_next_record: - if (chunk > len) - chunk = len; - decrypted += chunk; len -= chunk; @@ -2016,7 +2011,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, if (!skb) goto splice_read_end; - err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); + err = decrypt_skb_update(sk, skb, NULL, &zc, false); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto splice_read_end; |