diff options
-rw-r--r-- | drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h | 1 | ||||
-rw-r--r-- | include/net/handshake.h | 5 | ||||
-rw-r--r-- | include/net/tls.h | 4 | ||||
-rw-r--r-- | include/net/tls_prot.h | 68 | ||||
-rw-r--r-- | include/trace/events/handshake.h | 160 | ||||
-rw-r--r-- | net/handshake/Makefile | 2 | ||||
-rw-r--r-- | net/handshake/alert.c | 110 | ||||
-rw-r--r-- | net/handshake/handshake.h | 6 | ||||
-rw-r--r-- | net/handshake/tlshd.c | 23 | ||||
-rw-r--r-- | net/handshake/trace.c | 2 | ||||
-rw-r--r-- | net/sunrpc/svcsock.c | 50 | ||||
-rw-r--r-- | net/sunrpc/xprtsock.c | 45 | ||||
-rw-r--r-- | net/tls/tls.h | 1 |
13 files changed, 431 insertions, 46 deletions
diff --git a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h index 68562a82d036..62f62bff74a5 100644 --- a/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h +++ b/drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h @@ -22,6 +22,7 @@ #include <crypto/internal/hash.h> #include <linux/tls.h> #include <net/tls.h> +#include <net/tls_prot.h> #include <net/tls_toe.h> #include "t4fw_api.h" diff --git a/include/net/handshake.h b/include/net/handshake.h index 2e26e436e85f..8ebd4f9ed26e 100644 --- a/include/net/handshake.h +++ b/include/net/handshake.h @@ -40,5 +40,10 @@ int tls_server_hello_x509(const struct tls_handshake_args *args, gfp_t flags); int tls_server_hello_psk(const struct tls_handshake_args *args, gfp_t flags); bool tls_handshake_cancel(struct sock *sk); +void tls_handshake_close(struct socket *sock); + +u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *msg); +void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, + u8 *level, u8 *description); #endif /* _NET_HANDSHAKE_H */ diff --git a/include/net/tls.h b/include/net/tls.h index 5e71dd3df8ca..06fca9160346 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -69,10 +69,6 @@ extern const struct tls_cipher_size_desc tls_cipher_size_desc[]; #define TLS_CRYPTO_INFO_READY(info) ((info)->cipher_type) -#define TLS_RECORD_TYPE_ALERT 0x15 -#define TLS_RECORD_TYPE_HANDSHAKE 0x16 -#define TLS_RECORD_TYPE_DATA 0x17 - #define TLS_AAD_SPACE_SIZE 13 #define MAX_IV_SIZE 16 diff --git a/include/net/tls_prot.h b/include/net/tls_prot.h new file mode 100644 index 000000000000..68a40756440b --- /dev/null +++ b/include/net/tls_prot.h @@ -0,0 +1,68 @@ +/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. + * + * TLS Protocol definitions + * + * From https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml + */ + +#ifndef _TLS_PROT_H +#define _TLS_PROT_H + +/* + * TLS Record protocol: ContentType + */ +enum { + TLS_RECORD_TYPE_CHANGE_CIPHER_SPEC = 20, + TLS_RECORD_TYPE_ALERT = 21, + TLS_RECORD_TYPE_HANDSHAKE = 22, + TLS_RECORD_TYPE_DATA = 23, + TLS_RECORD_TYPE_HEARTBEAT = 24, + TLS_RECORD_TYPE_TLS12_CID = 25, + TLS_RECORD_TYPE_ACK = 26, +}; + +/* + * TLS Alert protocol: AlertLevel + */ +enum { + TLS_ALERT_LEVEL_WARNING = 1, + TLS_ALERT_LEVEL_FATAL = 2, +}; + +/* + * TLS Alert protocol: AlertDescription + */ +enum { + TLS_ALERT_DESC_CLOSE_NOTIFY = 0, + TLS_ALERT_DESC_UNEXPECTED_MESSAGE = 10, + TLS_ALERT_DESC_BAD_RECORD_MAC = 20, + TLS_ALERT_DESC_RECORD_OVERFLOW = 22, + TLS_ALERT_DESC_HANDSHAKE_FAILURE = 40, + TLS_ALERT_DESC_BAD_CERTIFICATE = 42, + TLS_ALERT_DESC_UNSUPPORTED_CERTIFICATE = 43, + TLS_ALERT_DESC_CERTIFICATE_REVOKED = 44, + TLS_ALERT_DESC_CERTIFICATE_EXPIRED = 45, + TLS_ALERT_DESC_CERTIFICATE_UNKNOWN = 46, + TLS_ALERT_DESC_ILLEGAL_PARAMETER = 47, + TLS_ALERT_DESC_UNKNOWN_CA = 48, + TLS_ALERT_DESC_ACCESS_DENIED = 49, + TLS_ALERT_DESC_DECODE_ERROR = 50, + TLS_ALERT_DESC_DECRYPT_ERROR = 51, + TLS_ALERT_DESC_TOO_MANY_CIDS_REQUESTED = 52, + TLS_ALERT_DESC_PROTOCOL_VERSION = 70, + TLS_ALERT_DESC_INSUFFICIENT_SECURITY = 71, + TLS_ALERT_DESC_INTERNAL_ERROR = 80, + TLS_ALERT_DESC_INAPPROPRIATE_FALLBACK = 86, + TLS_ALERT_DESC_USER_CANCELED = 90, + TLS_ALERT_DESC_MISSING_EXTENSION = 109, + TLS_ALERT_DESC_UNSUPPORTED_EXTENSION = 110, + TLS_ALERT_DESC_UNRECOGNIZED_NAME = 112, + TLS_ALERT_DESC_BAD_CERTIFICATE_STATUS_RESPONSE = 113, + TLS_ALERT_DESC_UNKNOWN_PSK_IDENTITY = 115, + TLS_ALERT_DESC_CERTIFICATE_REQUIRED = 116, + TLS_ALERT_DESC_NO_APPLICATION_PROTOCOL = 120, +}; + +#endif /* _TLS_PROT_H */ diff --git a/include/trace/events/handshake.h b/include/trace/events/handshake.h index 8dadcab5f12a..bdd8a03cf5ba 100644 --- a/include/trace/events/handshake.h +++ b/include/trace/events/handshake.h @@ -6,7 +6,86 @@ #define _TRACE_HANDSHAKE_H #include <linux/net.h> +#include <net/tls_prot.h> #include <linux/tracepoint.h> +#include <trace/events/net_probe_common.h> + +#define TLS_RECORD_TYPE_LIST \ + record_type(CHANGE_CIPHER_SPEC) \ + record_type(ALERT) \ + record_type(HANDSHAKE) \ + record_type(DATA) \ + record_type(HEARTBEAT) \ + record_type(TLS12_CID) \ + record_type_end(ACK) + +#undef record_type +#undef record_type_end +#define record_type(x) TRACE_DEFINE_ENUM(TLS_RECORD_TYPE_##x); +#define record_type_end(x) TRACE_DEFINE_ENUM(TLS_RECORD_TYPE_##x); + +TLS_RECORD_TYPE_LIST + +#undef record_type +#undef record_type_end +#define record_type(x) { TLS_RECORD_TYPE_##x, #x }, +#define record_type_end(x) { TLS_RECORD_TYPE_##x, #x } + +#define show_tls_content_type(type) \ + __print_symbolic(type, TLS_RECORD_TYPE_LIST) + +TRACE_DEFINE_ENUM(TLS_ALERT_LEVEL_WARNING); +TRACE_DEFINE_ENUM(TLS_ALERT_LEVEL_FATAL); + +#define show_tls_alert_level(level) \ + __print_symbolic(level, \ + { TLS_ALERT_LEVEL_WARNING, "Warning" }, \ + { TLS_ALERT_LEVEL_FATAL, "Fatal" }) + +#define TLS_ALERT_DESCRIPTION_LIST \ + alert_description(CLOSE_NOTIFY) \ + alert_description(UNEXPECTED_MESSAGE) \ + alert_description(BAD_RECORD_MAC) \ + alert_description(RECORD_OVERFLOW) \ + alert_description(HANDSHAKE_FAILURE) \ + alert_description(BAD_CERTIFICATE) \ + alert_description(UNSUPPORTED_CERTIFICATE) \ + alert_description(CERTIFICATE_REVOKED) \ + alert_description(CERTIFICATE_EXPIRED) \ + alert_description(CERTIFICATE_UNKNOWN) \ + alert_description(ILLEGAL_PARAMETER) \ + alert_description(UNKNOWN_CA) \ + alert_description(ACCESS_DENIED) \ + alert_description(DECODE_ERROR) \ + alert_description(DECRYPT_ERROR) \ + alert_description(TOO_MANY_CIDS_REQUESTED) \ + alert_description(PROTOCOL_VERSION) \ + alert_description(INSUFFICIENT_SECURITY) \ + alert_description(INTERNAL_ERROR) \ + alert_description(INAPPROPRIATE_FALLBACK) \ + alert_description(USER_CANCELED) \ + alert_description(MISSING_EXTENSION) \ + alert_description(UNSUPPORTED_EXTENSION) \ + alert_description(UNRECOGNIZED_NAME) \ + alert_description(BAD_CERTIFICATE_STATUS_RESPONSE) \ + alert_description(UNKNOWN_PSK_IDENTITY) \ + alert_description(CERTIFICATE_REQUIRED) \ + alert_description_end(NO_APPLICATION_PROTOCOL) + +#undef alert_description +#undef alert_description_end +#define alert_description(x) TRACE_DEFINE_ENUM(TLS_ALERT_DESC_##x); +#define alert_description_end(x) TRACE_DEFINE_ENUM(TLS_ALERT_DESC_##x); + +TLS_ALERT_DESCRIPTION_LIST + +#undef alert_description +#undef alert_description_end +#define alert_description(x) { TLS_ALERT_DESC_##x, #x }, +#define alert_description_end(x) { TLS_ALERT_DESC_##x, #x } + +#define show_tls_alert_description(desc) \ + __print_symbolic(desc, TLS_ALERT_DESCRIPTION_LIST) DECLARE_EVENT_CLASS(handshake_event_class, TP_PROTO( @@ -106,6 +185,47 @@ DECLARE_EVENT_CLASS(handshake_error_class, ), \ TP_ARGS(net, req, sk, err)) +DECLARE_EVENT_CLASS(handshake_alert_class, + TP_PROTO( + const struct sock *sk, + unsigned char level, + unsigned char description + ), + TP_ARGS(sk, level, description), + TP_STRUCT__entry( + /* sockaddr_in6 is always bigger than sockaddr_in */ + __array(__u8, saddr, sizeof(struct sockaddr_in6)) + __array(__u8, daddr, sizeof(struct sockaddr_in6)) + __field(unsigned int, netns_ino) + __field(unsigned long, level) + __field(unsigned long, description) + ), + TP_fast_assign( + const struct inet_sock *inet = inet_sk(sk); + + memset(__entry->saddr, 0, sizeof(struct sockaddr_in6)); + memset(__entry->daddr, 0, sizeof(struct sockaddr_in6)); + TP_STORE_ADDR_PORTS(__entry, inet, sk); + + __entry->netns_ino = sock_net(sk)->ns.inum; + __entry->level = level; + __entry->description = description; + ), + TP_printk("src=%pISpc dest=%pISpc %s: %s", + __entry->saddr, __entry->daddr, + show_tls_alert_level(__entry->level), + show_tls_alert_description(__entry->description) + ) +); +#define DEFINE_HANDSHAKE_ALERT(name) \ + DEFINE_EVENT(handshake_alert_class, name, \ + TP_PROTO( \ + const struct sock *sk, \ + unsigned char level, \ + unsigned char description \ + ), \ + TP_ARGS(sk, level, description)) + /* * Request lifetime events @@ -154,6 +274,46 @@ DEFINE_HANDSHAKE_ERROR(handshake_cmd_accept_err); DEFINE_HANDSHAKE_FD_EVENT(handshake_cmd_done); DEFINE_HANDSHAKE_ERROR(handshake_cmd_done_err); +/* + * TLS Record events + */ + +TRACE_EVENT(tls_contenttype, + TP_PROTO( + const struct sock *sk, + unsigned char type + ), + TP_ARGS(sk, type), + TP_STRUCT__entry( + /* sockaddr_in6 is always bigger than sockaddr_in */ + __array(__u8, saddr, sizeof(struct sockaddr_in6)) + __array(__u8, daddr, sizeof(struct sockaddr_in6)) + __field(unsigned int, netns_ino) + __field(unsigned long, type) + ), + TP_fast_assign( + const struct inet_sock *inet = inet_sk(sk); + + memset(__entry->saddr, 0, sizeof(struct sockaddr_in6)); + memset(__entry->daddr, 0, sizeof(struct sockaddr_in6)); + TP_STORE_ADDR_PORTS(__entry, inet, sk); + + __entry->netns_ino = sock_net(sk)->ns.inum; + __entry->type = type; + ), + TP_printk("src=%pISpc dest=%pISpc %s", + __entry->saddr, __entry->daddr, + show_tls_content_type(__entry->type) + ) +); + +/* + * TLS Alert events + */ + +DEFINE_HANDSHAKE_ALERT(tls_alert_send); +DEFINE_HANDSHAKE_ALERT(tls_alert_recv); + #endif /* _TRACE_HANDSHAKE_H */ #include <trace/define_trace.h> diff --git a/net/handshake/Makefile b/net/handshake/Makefile index 247d73c6ff6e..ef4d9a2112bd 100644 --- a/net/handshake/Makefile +++ b/net/handshake/Makefile @@ -8,6 +8,6 @@ # obj-y += handshake.o -handshake-y := genl.o netlink.o request.o tlshd.o trace.o +handshake-y := alert.o genl.o netlink.o request.o tlshd.o trace.o obj-$(CONFIG_NET_HANDSHAKE_KUNIT_TEST) += handshake-test.o diff --git a/net/handshake/alert.c b/net/handshake/alert.c new file mode 100644 index 000000000000..329d91984683 --- /dev/null +++ b/net/handshake/alert.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Handle the TLS Alert protocol + * + * Author: Chuck Lever <chuck.lever@oracle.com> + * + * Copyright (c) 2023, Oracle and/or its affiliates. + */ + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/kernel.h> +#include <linux/module.h> +#include <linux/skbuff.h> +#include <linux/inet.h> + +#include <net/sock.h> +#include <net/handshake.h> +#include <net/tls.h> +#include <net/tls_prot.h> + +#include "handshake.h" + +#include <trace/events/handshake.h> + +/** + * tls_alert_send - send a TLS Alert on a kTLS socket + * @sock: open kTLS socket to send on + * @level: TLS Alert level + * @description: TLS Alert description + * + * Returns zero on success or a negative errno. + */ +int tls_alert_send(struct socket *sock, u8 level, u8 description) +{ + u8 record_type = TLS_RECORD_TYPE_ALERT; + u8 buf[CMSG_SPACE(sizeof(record_type))]; + struct msghdr msg = { 0 }; + struct cmsghdr *cmsg; + struct kvec iov; + u8 alert[2]; + int ret; + + trace_tls_alert_send(sock->sk, level, description); + + alert[0] = level; + alert[1] = description; + iov.iov_base = alert; + iov.iov_len = sizeof(alert); + + memset(buf, 0, sizeof(buf)); + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + msg.msg_flags = MSG_DONTWAIT; + + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_TLS; + cmsg->cmsg_type = TLS_SET_RECORD_TYPE; + cmsg->cmsg_len = CMSG_LEN(sizeof(record_type)); + memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type)); + + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len); + ret = sock_sendmsg(sock, &msg); + return ret < 0 ? ret : 0; +} + +/** + * tls_get_record_type - Look for TLS RECORD_TYPE information + * @sk: socket (for IP address information) + * @cmsg: incoming message to be parsed + * + * Returns zero or a TLS_RECORD_TYPE value. + */ +u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg) +{ + u8 record_type; + + if (cmsg->cmsg_level != SOL_TLS) + return 0; + if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE) + return 0; + + record_type = *((u8 *)CMSG_DATA(cmsg)); + trace_tls_contenttype(sk, record_type); + return record_type; +} +EXPORT_SYMBOL(tls_get_record_type); + +/** + * tls_alert_recv - Parse TLS Alert messages + * @sk: socket (for IP address information) + * @msg: incoming message to be parsed + * @level: OUT - TLS AlertLevel value + * @description: OUT - TLS AlertDescription value + * + */ +void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, + u8 *level, u8 *description) +{ + const struct kvec *iov; + u8 *data; + + iov = msg->msg_iter.kvec; + data = iov->iov_base; + *level = data[0]; + *description = data[1]; + + trace_tls_alert_recv(sk, *level, *description); +} +EXPORT_SYMBOL(tls_alert_recv); diff --git a/net/handshake/handshake.h b/net/handshake/handshake.h index 4dac965c99df..a48163765a7a 100644 --- a/net/handshake/handshake.h +++ b/net/handshake/handshake.h @@ -41,8 +41,11 @@ struct handshake_req { enum hr_flags_bits { HANDSHAKE_F_REQ_COMPLETED, + HANDSHAKE_F_REQ_SESSION, }; +struct genl_info; + /* Invariants for all handshake requests for one transport layer * security protocol */ @@ -63,6 +66,9 @@ enum hp_flags_bits { HANDSHAKE_F_PROTO_NOTIFY, }; +/* alert.c */ +int tls_alert_send(struct socket *sock, u8 level, u8 description); + /* netlink.c */ int handshake_genl_notify(struct net *net, const struct handshake_proto *proto, gfp_t flags); diff --git a/net/handshake/tlshd.c b/net/handshake/tlshd.c index b735f5cced2f..bbfb4095ddd6 100644 --- a/net/handshake/tlshd.c +++ b/net/handshake/tlshd.c @@ -18,6 +18,7 @@ #include <net/sock.h> #include <net/handshake.h> #include <net/genetlink.h> +#include <net/tls_prot.h> #include <uapi/linux/keyctl.h> #include <uapi/linux/handshake.h> @@ -100,6 +101,9 @@ static void tls_handshake_done(struct handshake_req *req, if (info) tls_handshake_remote_peerids(treq, info); + if (!status) + set_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags); + treq->th_consumer_done(treq->th_consumer_data, -status, treq->th_peerid[0]); } @@ -424,3 +428,22 @@ bool tls_handshake_cancel(struct sock *sk) return handshake_req_cancel(sk); } EXPORT_SYMBOL(tls_handshake_cancel); + +/** + * tls_handshake_close - send a Closure alert + * @sock: an open socket + * + */ +void tls_handshake_close(struct socket *sock) +{ + struct handshake_req *req; + + req = handshake_req_hash_lookup(sock->sk); + if (!req) + return; + if (!test_and_clear_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags)) + return; + tls_alert_send(sock, TLS_ALERT_LEVEL_WARNING, + TLS_ALERT_DESC_CLOSE_NOTIFY); +} +EXPORT_SYMBOL(tls_handshake_close); diff --git a/net/handshake/trace.c b/net/handshake/trace.c index 1c4d8e27e17a..44432d0857b9 100644 --- a/net/handshake/trace.c +++ b/net/handshake/trace.c @@ -8,8 +8,10 @@ */ #include <linux/types.h> +#include <linux/ipv6.h> #include <net/sock.h> +#include <net/inet_sock.h> #include <net/netlink.h> #include <net/genetlink.h> diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index e43f26382411..2ed29e40c6a9 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -43,7 +43,7 @@ #include <net/udp.h> #include <net/tcp.h> #include <net/tcp_states.h> -#include <net/tls.h> +#include <net/tls_prot.h> #include <net/handshake.h> #include <linux/uaccess.h> #include <linux/highmem.h> @@ -226,27 +226,30 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) } static int -svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg, +svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, struct cmsghdr *cmsg, int ret) { - if (cmsg->cmsg_level == SOL_TLS && - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); - - switch (content_type) { - case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - msg->msg_flags &= ~MSG_EOR; - break; - case TLS_RECORD_TYPE_ALERT: - ret = -ENOTCONN; - break; - default: - ret = -EAGAIN; - } + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; + + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -ENOTCONN : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; } return ret; } @@ -258,13 +261,14 @@ svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg) struct cmsghdr cmsg; u8 buf[CMSG_SPACE(sizeof(u8))]; } u; + struct socket *sock = svsk->sk_sock; int ret; msg->msg_control = &u; msg->msg_controllen = sizeof(u); - ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT); + ret = sock_recvmsg(sock, msg, MSG_DONTWAIT); if (unlikely(msg->msg_controllen != sizeof(u))) - ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret); + ret = svc_tcp_sock_process_cmsg(sock, msg, &u.cmsg, ret); return ret; } @@ -1621,6 +1625,8 @@ static void svc_tcp_sock_detach(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + tls_handshake_close(svsk->sk_sock); + svc_sock_detach(xprt); if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) { diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 9f010369100a..268a2cc61acd 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,7 +47,7 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> -#include <net/tls.h> +#include <net/tls_prot.h> #include <net/handshake.h> #include <linux/bvec.h> @@ -360,24 +360,27 @@ static int xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, struct cmsghdr *cmsg, int ret) { - if (cmsg->cmsg_level == SOL_TLS && - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); - - switch (content_type) { - case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - msg->msg_flags &= ~MSG_EOR; - break; - case TLS_RECORD_TYPE_ALERT: - ret = -ENOTCONN; - break; - default: - ret = -EAGAIN; - } + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; + + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -EACCES : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; } return ret; } @@ -777,6 +780,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport) } if (ret == -ESHUTDOWN) kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else if (ret == -EACCES) + xprt_wake_pending_tasks(&transport->xprt, -EACCES); else xs_poll_check_readable(transport); out: @@ -1292,6 +1297,8 @@ static void xs_close(struct rpc_xprt *xprt) dprintk("RPC: xs_close xprt %p\n", xprt); + if (transport->sock) + tls_handshake_close(transport->sock); xs_reset_transport(transport); xprt->reestablish_timeout = 0; } diff --git a/net/tls/tls.h b/net/tls/tls.h index 7e4d45537deb..37539ac3ac2a 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -39,6 +39,7 @@ #include <linux/types.h> #include <linux/skmsg.h> #include <net/tls.h> +#include <net/tls_prot.h> #define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER, \ TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT)) |