aboutsummaryrefslogtreecommitdiff
path: root/net/ipv4
diff options
context:
space:
mode:
authorCong Wang2021-03-30 19:32:31 -0700
committerAlexei Starovoitov2021-04-01 10:56:14 -0700
commit8a59f9d1e3d4340659fdfee8879dc09a6f2546e1 (patch)
tree22d59950e745c697e01c05b82ed67d7a9536bd88 /net/ipv4
parenta7ba4558e69a3c2ae4ca521f015832ef44799538 (diff)
sock: Introduce sk->sk_prot->psock_update_sk_prot()
Currently sockmap calls into each protocol to update the struct proto and replace it. This certainly won't work when the protocol is implemented as a module, for example, AF_UNIX. Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each protocol can implement its own way to replace the struct proto. This also helps get rid of symbol dependencies on CONFIG_INET. Signed-off-by: Cong Wang <cong.wang@bytedance.com> Signed-off-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/bpf/20210331023237.41094-11-xiyou.wangcong@gmail.com
Diffstat (limited to 'net/ipv4')
-rw-r--r--net/ipv4/tcp_bpf.c24
-rw-r--r--net/ipv4/tcp_ipv4.c3
-rw-r--r--net/ipv4/udp.c3
-rw-r--r--net/ipv4/udp_bpf.c15
4 files changed, 40 insertions, 5 deletions
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index ae980716d896..ac8cfbaeacd2 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
-struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int tcp_bpf_update_proto(struct sock *sk, bool restore)
{
+ struct sk_psock *psock = sk_psock(sk);
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
+ if (restore) {
+ if (inet_csk_has_ulp(sk)) {
+ tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
+ } else {
+ sk->sk_write_space = psock->saved_write_space;
+ /* Pairs with lockless read in sk_clone_lock() */
+ WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+ }
+ return 0;
+ }
+
+ if (inet_csk_has_ulp(sk))
+ return -EINVAL;
+
if (sk->sk_family == AF_INET6) {
if (tcp_bpf_assert_proto_ops(psock->sk_proto))
- return ERR_PTR(-EINVAL);
+ return -EINVAL;
tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
}
- return &tcp_bpf_prots[family][config];
+ /* Pairs with lockless read in sk_clone_lock() */
+ WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
+ return 0;
}
+EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
/* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index daad4f99db32..dfc6d1c0e710 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
.hash = inet_hash,
.unhash = inet_unhash,
.get_port = inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+ .psock_update_sk_prot = tcp_bpf_update_proto,
+#endif
.enter_memory_pressure = tcp_enter_memory_pressure,
.leave_memory_pressure = tcp_leave_memory_pressure,
.stream_memory_free = tcp_stream_memory_free,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 4a0478b17243..38952aaee3a1 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2849,6 +2849,9 @@ struct proto udp_prot = {
.unhash = udp_lib_unhash,
.rehash = udp_v4_rehash,
.get_port = udp_v4_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+ .psock_update_sk_prot = udp_bpf_update_proto,
+#endif
.memory_allocated = &udp_memory_allocated,
.sysctl_mem = sysctl_udp_mem,
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 7a94791efc1a..6001f93cd3a0 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
}
core_initcall(udp_bpf_v4_build_proto);
-struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int udp_bpf_update_proto(struct sock *sk, bool restore)
{
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
+ struct sk_psock *psock = sk_psock(sk);
+
+ if (restore) {
+ sk->sk_write_space = psock->saved_write_space;
+ /* Pairs with lockless read in sk_clone_lock() */
+ WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+ return 0;
+ }
if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
- return &udp_bpf_prots[family];
+ /* Pairs with lockless read in sk_clone_lock() */
+ WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
+ return 0;
}
+EXPORT_SYMBOL_GPL(udp_bpf_update_proto);