Diffstat (limited to 'net/ipv4')
38 files changed, 1161 insertions, 548 deletions
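Note on the largest change in this series: the inet_connection_sock.c and inet_hashtables.c hunks below add a second bind table ("bhash2") keyed on (address, port) next to the existing port-only bhash, so bind-conflict checks and inet_bhash2_update_saddr() only have to walk sockets that share the local address. The following is a minimal userspace sketch of that bucket-selection idea only; the table size, hash function and names here are assumptions for illustration, not the kernel's (which uses ipv4_portaddr_hash()/ipv6_portaddr_hash() via inet_bhashfn_portaddr() and hinfo->bhash2).

/*
 * Illustrative sketch (not kernel code): port-only bucketing vs.
 * (address, port) bucketing.  BHASH_SIZE and the mixing function are
 * stand-ins chosen only to keep the example self-contained.
 */
#include <stdint.h>
#include <stdio.h>

#define BHASH_SIZE 64	/* assumed power-of-two table size */

/* Classic bhash: every socket bound to the same port shares one bucket. */
static unsigned int bhash_bucket(uint16_t port)
{
	return port & (BHASH_SIZE - 1);
}

/* bhash2-style: mix the local address into the bucket choice, so sockets
 * bound to different addresses on the same port usually land apart. */
static unsigned int bhash2_bucket(uint32_t saddr, uint16_t port)
{
	uint32_t h = saddr ^ (saddr >> 16);	/* stand-in for jhash */

	h = h * 2654435761u + port;
	return h & (BHASH_SIZE - 1);
}

int main(void)
{
	uint32_t a = 0xc0a80001;	/* 192.168.0.1 */
	uint32_t b = 0xc0a80002;	/* 192.168.0.2 */
	uint16_t port = 8080;

	printf("bhash : %u vs %u\n", bhash_bucket(port), bhash_bucket(port));
	printf("bhash2: %u vs %u\n", bhash2_bucket(a, port), bhash2_bucket(b, port));
	return 0;
}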
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 3ca0cc467886..e2c219382345 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -1219,6 +1219,7 @@ EXPORT_SYMBOL(inet_unregister_protosw); static int inet_sk_reselect_saddr(struct sock *sk) { + struct inet_bind_hashbucket *prev_addr_hashbucket; struct inet_sock *inet = inet_sk(sk); __be32 old_saddr = inet->inet_saddr; __be32 daddr = inet->inet_daddr; @@ -1226,6 +1227,7 @@ static int inet_sk_reselect_saddr(struct sock *sk) struct rtable *rt; __be32 new_saddr; struct ip_options_rcu *inet_opt; + int err; inet_opt = rcu_dereference_protected(inet->inet_opt, lockdep_sock_is_held(sk)); @@ -1240,20 +1242,34 @@ static int inet_sk_reselect_saddr(struct sock *sk) if (IS_ERR(rt)) return PTR_ERR(rt); - sk_setup_caps(sk, &rt->dst); - new_saddr = fl4->saddr; - if (new_saddr == old_saddr) + if (new_saddr == old_saddr) { + sk_setup_caps(sk, &rt->dst); return 0; + } + + prev_addr_hashbucket = + inet_bhashfn_portaddr(tcp_or_dccp_get_hashinfo(sk), sk, + sock_net(sk), inet->inet_num); + + inet->inet_saddr = inet->inet_rcv_saddr = new_saddr; + + err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); + if (err) { + inet->inet_saddr = old_saddr; + inet->inet_rcv_saddr = old_saddr; + ip_rt_put(rt); + return err; + } + + sk_setup_caps(sk, &rt->dst); if (READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) > 1) { pr_info("%s(): shifting inet->saddr from %pI4 to %pI4\n", __func__, &old_saddr, &new_saddr); } - inet->inet_saddr = inet->inet_rcv_saddr = new_saddr; - /* * XXX The only one ugly spot where we need to * XXX really change the sockets identity after @@ -1448,12 +1464,9 @@ struct sk_buff *inet_gro_receive(struct list_head *head, struct sk_buff *skb) off = skb_gro_offset(skb); hlen = off + sizeof(*iph); - iph = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - iph = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!iph)) - goto out; - } + iph = skb_gro_header(skb, hlen, off); + if (unlikely(!iph)) + goto out; proto = iph->protocol; diff --git a/net/ipv4/ah4.c b/net/ipv4/ah4.c index f8ad04470d3a..ee4e578c7f20 100644 --- a/net/ipv4/ah4.c +++ b/net/ipv4/ah4.c @@ -471,30 +471,38 @@ static int ah4_err(struct sk_buff *skb, u32 info) return 0; } -static int ah_init_state(struct xfrm_state *x) +static int ah_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { struct ah_data *ahp = NULL; struct xfrm_algo_desc *aalg_desc; struct crypto_ahash *ahash; - if (!x->aalg) + if (!x->aalg) { + NL_SET_ERR_MSG(extack, "AH requires a state with an AUTH algorithm"); goto error; + } - if (x->encap) + if (x->encap) { + NL_SET_ERR_MSG(extack, "AH is not compatible with encapsulation"); goto error; + } ahp = kzalloc(sizeof(*ahp), GFP_KERNEL); if (!ahp) return -ENOMEM; ahash = crypto_alloc_ahash(x->aalg->alg_name, 0, 0); - if (IS_ERR(ahash)) + if (IS_ERR(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } ahp->ahash = ahash; if (crypto_ahash_setkey(ahash, x->aalg->alg_key, - (x->aalg->alg_key_len + 7) / 8)) + (x->aalg->alg_key_len + 7) / 8)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } /* * Lookup the algorithm description maintained by xfrm_algo, @@ -507,10 +515,7 @@ static int ah_init_state(struct xfrm_state *x) if (aalg_desc->uinfo.auth.icv_fullbits/8 != crypto_ahash_digestsize(ahash)) { - pr_info("%s: %s digestsize %u != %u\n", - __func__, x->aalg->alg_name, - crypto_ahash_digestsize(ahash), - 
aalg_desc->uinfo.auth.icv_fullbits / 8); + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; } diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c index 87c7e3fc5197..4f7237661afb 100644 --- a/net/ipv4/arp.c +++ b/net/ipv4/arp.c @@ -1129,7 +1129,7 @@ static int arp_req_get(struct arpreq *r, struct net_device *dev) r->arp_flags = arp_state_to_flags(neigh); read_unlock_bh(&neigh->lock); r->arp_ha.sa_family = dev->type; - strlcpy(r->arp_dev, dev->name, sizeof(r->arp_dev)); + strscpy(r->arp_dev, dev->name, sizeof(r->arp_dev)); err = 0; } neigh_release(neigh); diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 85a9e500c42d..6da16ae6a962 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -124,7 +124,7 @@ static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, return -EACCES; } - return NOT_INIT; + return 0; } BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt) diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c index ffd57523331f..405a8c2aea64 100644 --- a/net/ipv4/datagram.c +++ b/net/ipv4/datagram.c @@ -42,6 +42,8 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len oif = inet->mc_index; if (!saddr) saddr = inet->mc_addr; + } else if (!oif) { + oif = inet->uc_index; } fl4 = &inet->cork.fl.u.ip4; rt = ip_route_connect(fl4, usin->sin_addr.s_addr, saddr, oif, diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c index 5c03eba787e5..52c8047efedb 100644 --- a/net/ipv4/esp4.c +++ b/net/ipv4/esp4.c @@ -134,6 +134,7 @@ static void esp_free_tcp_sk(struct rcu_head *head) static struct sock *esp_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; + struct net *net = xs_net(x); struct esp_tcp_sk *esk; __be16 sport, dport; struct sock *nsk; @@ -160,7 +161,7 @@ static struct sock *esp_find_tcp_sk(struct xfrm_state *x) } spin_unlock_bh(&x->lock); - sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4, + sk = inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, x->id.daddr.a4, dport, x->props.saddr.a4, sport, 0); if (!sk) return ERR_PTR(-ENOENT); @@ -1007,16 +1008,17 @@ static void esp_destroy(struct xfrm_state *x) crypto_free_aead(aead); } -static int esp_init_aead(struct xfrm_state *x) +static int esp_init_aead(struct xfrm_state *x, struct netlink_ext_ack *extack) { char aead_name[CRYPTO_MAX_ALG_NAME]; struct crypto_aead *aead; int err; - err = -ENAMETOOLONG; if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)", - x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) - goto error; + x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + return -ENAMETOOLONG; + } aead = crypto_alloc_aead(aead_name, 0, 0); err = PTR_ERR(aead); @@ -1034,11 +1036,15 @@ static int esp_init_aead(struct xfrm_state *x) if (err) goto error; + return 0; + error: + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); return err; } -static int esp_init_authenc(struct xfrm_state *x) +static int esp_init_authenc(struct xfrm_state *x, + struct netlink_ext_ack *extack) { struct crypto_aead *aead; struct crypto_authenc_key_param *param; @@ -1049,10 +1055,6 @@ static int esp_init_authenc(struct xfrm_state *x) unsigned int keylen; int err; - err = -EINVAL; - if (!x->ealg) - goto error; - err = -ENAMETOOLONG; if ((x->props.flags & XFRM_STATE_ESN)) { @@ -1061,22 +1063,28 @@ static int esp_init_authenc(struct xfrm_state *x) x->geniv ?: "", x->geniv ? "(" : "", x->aalg ? 
x->aalg->alg_name : "digest_null", x->ealg->alg_name, - x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); goto error; + } } else { if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, "%s%sauthenc(%s,%s)%s", x->geniv ?: "", x->geniv ? "(" : "", x->aalg ? x->aalg->alg_name : "digest_null", x->ealg->alg_name, - x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); goto error; + } } aead = crypto_alloc_aead(authenc_name, 0, 0); err = PTR_ERR(aead); - if (IS_ERR(aead)) + if (IS_ERR(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } x->data = aead; @@ -1106,17 +1114,16 @@ static int esp_init_authenc(struct xfrm_state *x) err = -EINVAL; if (aalg_desc->uinfo.auth.icv_fullbits / 8 != crypto_aead_authsize(aead)) { - pr_info("ESP: %s digestsize %u != %u\n", - x->aalg->alg_name, - crypto_aead_authsize(aead), - aalg_desc->uinfo.auth.icv_fullbits / 8); + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto free_key; } err = crypto_aead_setauthsize( aead, x->aalg->alg_trunc_len / 8); - if (err) + if (err) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto free_key; + } } param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8); @@ -1131,7 +1138,7 @@ error: return err; } -static int esp_init_state(struct xfrm_state *x) +static int esp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { struct crypto_aead *aead; u32 align; @@ -1139,10 +1146,14 @@ static int esp_init_state(struct xfrm_state *x) x->data = NULL; - if (x->aead) - err = esp_init_aead(x); - else - err = esp_init_authenc(x); + if (x->aead) { + err = esp_init_aead(x, extack); + } else if (x->ealg) { + err = esp_init_authenc(x, extack); + } else { + NL_SET_ERR_MSG(extack, "ESP: AEAD or CRYPT must be provided"); + err = -EINVAL; + } if (err) goto error; @@ -1160,6 +1171,7 @@ static int esp_init_state(struct xfrm_state *x) switch (encap->encap_type) { default: + NL_SET_ERR_MSG(extack, "Unsupported encapsulation type for ESP"); err = -EINVAL; goto error; case UDP_ENCAP_ESPINUDP: diff --git a/net/ipv4/esp4_offload.c b/net/ipv4/esp4_offload.c index 935026f4c807..170152772d33 100644 --- a/net/ipv4/esp4_offload.c +++ b/net/ipv4/esp4_offload.c @@ -110,7 +110,10 @@ static struct sk_buff *xfrm4_tunnel_gso_segment(struct xfrm_state *x, struct sk_buff *skb, netdev_features_t features) { - return skb_eth_gso_segment(skb, features, htons(ETH_P_IP)); + __be16 type = x->inner_mode.family == AF_INET6 ? 
htons(ETH_P_IPV6) + : htons(ETH_P_IP); + + return skb_eth_gso_segment(skb, features, type); } static struct sk_buff *xfrm4_transport_gso_segment(struct xfrm_state *x, diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c index 025a33c1b04d..0c3c6d0cee29 100644 --- a/net/ipv4/fou.c +++ b/net/ipv4/fou.c @@ -323,12 +323,9 @@ static struct sk_buff *gue_gro_receive(struct sock *sk, off = skb_gro_offset(skb); len = off + sizeof(*guehdr); - guehdr = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, len)) { - guehdr = skb_gro_header_slow(skb, len, off); - if (unlikely(!guehdr)) - goto out; - } + guehdr = skb_gro_header(skb, len, off); + if (unlikely(!guehdr)) + goto out; switch (guehdr->version) { case 0: @@ -931,6 +928,7 @@ static struct genl_family fou_nl_family __ro_after_init = { .module = THIS_MODULE, .small_ops = fou_nl_ops, .n_small_ops = ARRAY_SIZE(fou_nl_ops), + .resv_start_op = FOU_CMD_GET + 1, }; size_t fou_encap_hlen(struct ip_tunnel_encap *e) diff --git a/net/ipv4/gre_offload.c b/net/ipv4/gre_offload.c index 07073fa35205..2b9cb5398335 100644 --- a/net/ipv4/gre_offload.c +++ b/net/ipv4/gre_offload.c @@ -137,12 +137,9 @@ static struct sk_buff *gre_gro_receive(struct list_head *head, off = skb_gro_offset(skb); hlen = off + sizeof(*greh); - greh = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - greh = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!greh)) - goto out; - } + greh = skb_gro_header(skb, hlen, off); + if (unlikely(!greh)) + goto out; /* Only support version 0 and K (key), C (csum) flags. Note that * although the support for the S (seq#) flag can be added easily diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c index e3ab0cb61624..df0660d818ac 100644 --- a/net/ipv4/igmp.c +++ b/net/ipv4/igmp.c @@ -2529,11 +2529,10 @@ done: err = ip_mc_leave_group(sk, &imr); return err; } - int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, - struct ip_msfilter __user *optval, int __user *optlen) + sockptr_t optval, sockptr_t optlen) { - int err, len, count, copycount; + int err, len, count, copycount, msf_size; struct ip_mreqn imr; __be32 addr = msf->imsf_multiaddr; struct ip_mc_socklist *pmc; @@ -2575,12 +2574,15 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, copycount = count < msf->imsf_numsrc ? count : msf->imsf_numsrc; len = flex_array_size(psl, sl_addr, copycount); msf->imsf_numsrc = count; - if (put_user(IP_MSFILTER_SIZE(copycount), optlen) || - copy_to_user(optval, msf, IP_MSFILTER_SIZE(0))) { + msf_size = IP_MSFILTER_SIZE(copycount); + if (copy_to_sockptr(optlen, &msf_size, sizeof(int)) || + copy_to_sockptr(optval, msf, IP_MSFILTER_SIZE(0))) { return -EFAULT; } if (len && - copy_to_user(&optval->imsf_slist_flex[0], psl->sl_addr, len)) + copy_to_sockptr_offset(optval, + offsetof(struct ip_msfilter, imsf_slist_flex), + psl->sl_addr, len)) return -EFAULT; return 0; done: @@ -2588,7 +2590,7 @@ done: } int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, - struct sockaddr_storage __user *p) + sockptr_t optval, size_t ss_offset) { int i, count, copycount; struct sockaddr_in *psin; @@ -2618,15 +2620,17 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, count = psl ? psl->sl_count : 0; copycount = count < gsf->gf_numsrc ? 
count : gsf->gf_numsrc; gsf->gf_numsrc = count; - for (i = 0; i < copycount; i++, p++) { + for (i = 0; i < copycount; i++) { struct sockaddr_storage ss; psin = (struct sockaddr_in *)&ss; memset(&ss, 0, sizeof(ss)); psin->sin_family = AF_INET; psin->sin_addr.s_addr = psl->sl_addr[i]; - if (copy_to_user(p, &ss, sizeof(ss))) + if (copy_to_sockptr_offset(optval, ss_offset, + &ss, sizeof(ss))) return -EFAULT; + ss_offset += sizeof(ss); } return 0; } diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index eb31c7158b39..ebca860e113f 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -130,14 +130,75 @@ void inet_get_local_port_range(struct net *net, int *low, int *high) } EXPORT_SYMBOL(inet_get_local_port_range); +static bool inet_use_bhash2_on_bind(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + + return addr_type != IPV6_ADDR_ANY && + addr_type != IPV6_ADDR_MAPPED; + } +#endif + return sk->sk_rcv_saddr != htonl(INADDR_ANY); +} + +static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2, + kuid_t sk_uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) +{ + int bound_dev_if2; + + if (sk == sk2) + return false; + + bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if); + + if (!sk->sk_bound_dev_if || !bound_dev_if2 || + sk->sk_bound_dev_if == bound_dev_if2) { + if (sk->sk_reuse && sk2->sk_reuse && + sk2->sk_state != TCP_LISTEN) { + if (!relax || (!reuseport_ok && sk->sk_reuseport && + sk2->sk_reuseport && reuseport_cb_ok && + (sk2->sk_state == TCP_TIME_WAIT || + uid_eq(sk_uid, sock_i_uid(sk2))))) + return true; + } else if (!reuseport_ok || !sk->sk_reuseport || + !sk2->sk_reuseport || !reuseport_cb_ok || + (sk2->sk_state != TCP_TIME_WAIT && + !uid_eq(sk_uid, sock_i_uid(sk2)))) { + return true; + } + } + return false; +} + +static bool inet_bhash2_conflict(const struct sock *sk, + const struct inet_bind2_bucket *tb2, + kuid_t sk_uid, + bool relax, bool reuseport_cb_ok, + bool reuseport_ok) +{ + struct sock *sk2; + + sk_for_each_bound_bhash2(sk2, &tb2->owners) { + if (sk->sk_family == AF_INET && ipv6_only_sock(sk2)) + continue; + + if (inet_bind_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok)) + return true; + } + return false; +} + +/* This should be called only when the tb and tb2 hashbuckets' locks are held */ static int inet_csk_bind_conflict(const struct sock *sk, const struct inet_bind_bucket *tb, + const struct inet_bind2_bucket *tb2, /* may be null */ bool relax, bool reuseport_ok) { - struct sock *sk2; bool reuseport_cb_ok; - bool reuse = sk->sk_reuse; - bool reuseport = !!sk->sk_reuseport; struct sock_reuseport *reuseport_cb; kuid_t uid = sock_i_uid((struct sock *)sk); @@ -150,58 +211,88 @@ static int inet_csk_bind_conflict(const struct sock *sk, /* * Unlike other sk lookup places we do not check * for sk_net here, since _all_ the socks listed - * in tb->owners list belong to the same net - the - * one this bucket belongs to. + * in tb->owners and tb2->owners list belong + * to the same net - the one this bucket belongs to. 
*/ - sk_for_each_bound(sk2, &tb->owners) { - int bound_dev_if2; + if (!inet_use_bhash2_on_bind(sk)) { + struct sock *sk2; - if (sk == sk2) - continue; - bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if); - if ((!sk->sk_bound_dev_if || - !bound_dev_if2 || - sk->sk_bound_dev_if == bound_dev_if2)) { - if (reuse && sk2->sk_reuse && - sk2->sk_state != TCP_LISTEN) { - if ((!relax || - (!reuseport_ok && - reuseport && sk2->sk_reuseport && - reuseport_cb_ok && - (sk2->sk_state == TCP_TIME_WAIT || - uid_eq(uid, sock_i_uid(sk2))))) && - inet_rcv_saddr_equal(sk, sk2, true)) - break; - } else if (!reuseport_ok || - !reuseport || !sk2->sk_reuseport || - !reuseport_cb_ok || - (sk2->sk_state != TCP_TIME_WAIT && - !uid_eq(uid, sock_i_uid(sk2)))) { - if (inet_rcv_saddr_equal(sk, sk2, true)) - break; - } - } + sk_for_each_bound(sk2, &tb->owners) + if (inet_bind_conflict(sk, sk2, uid, relax, + reuseport_cb_ok, reuseport_ok) && + inet_rcv_saddr_equal(sk, sk2, true)) + return true; + + return false; + } + + /* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if + * ipv4) should have been checked already. We need to do these two + * checks separately because their spinlocks have to be acquired/released + * independently of each other, to prevent possible deadlocks + */ + return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok, + reuseport_ok); +} + +/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or + * INADDR_ANY (if ipv4) socket. + * + * Caller must hold bhash hashbucket lock with local bh disabled, to protect + * against concurrent binds on the port for addr any + */ +static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev, + bool relax, bool reuseport_ok) +{ + kuid_t uid = sock_i_uid((struct sock *)sk); + const struct net *net = sock_net(sk); + struct sock_reuseport *reuseport_cb; + struct inet_bind_hashbucket *head2; + struct inet_bind2_bucket *tb2; + bool reuseport_cb_ok; + + rcu_read_lock(); + reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); + /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */ + reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); + rcu_read_unlock(); + + head2 = inet_bhash2_addr_any_hashbucket(sk, net, port); + + spin_lock(&head2->lock); + + inet_bind_bucket_for_each(tb2, &head2->chain) + if (inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk)) + break; + + if (tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok, + reuseport_ok)) { + spin_unlock(&head2->lock); + return true; } - return sk2 != NULL; + + spin_unlock(&head2->lock); + return false; } /* * Find an open port number for the socket. Returns with the - * inet_bind_hashbucket lock held. + * inet_bind_hashbucket locks held if successful. 
*/ static struct inet_bind_hashbucket * -inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret) +inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret, + struct inet_bind2_bucket **tb2_ret, + struct inet_bind_hashbucket **head2_ret, int *port_ret) { - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; - int port = 0; - struct inet_bind_hashbucket *head; + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + int i, low, high, attempt_half, port, l3mdev; + struct inet_bind_hashbucket *head, *head2; struct net *net = sock_net(sk); - bool relax = false; - int i, low, high, attempt_half; + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; u32 remaining, offset; - int l3mdev; + bool relax = false; l3mdev = inet_sk_bound_l3mdev(sk); ports_exhausted: @@ -239,11 +330,20 @@ other_parity_scan: head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)]; spin_lock_bh(&head->lock); + if (inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false)) + goto next_port; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); inet_bind_bucket_for_each(tb, &head->chain) - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) { - if (!inet_csk_bind_conflict(sk, tb, relax, false)) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { + if (!inet_csk_bind_conflict(sk, tb, tb2, + relax, false)) goto success; + spin_unlock(&head2->lock); goto next_port; } tb = NULL; @@ -272,6 +372,8 @@ next_port: success: *port_ret = port; *tb_ret = tb; + *tb2_ret = tb2; + *head2_ret = head2; return head; } @@ -365,56 +467,97 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, */ int inet_csk_get_port(struct sock *sk, unsigned short snum) { + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; - int ret = 1, port = snum; - struct inet_bind_hashbucket *head; - struct net *net = sock_net(sk); + bool found_port = false, check_bind_conflict = true; + bool bhash_created = false, bhash2_created = false; + struct inet_bind_hashbucket *head, *head2; + struct inet_bind2_bucket *tb2 = NULL; struct inet_bind_bucket *tb = NULL; - int l3mdev; + bool head2_lock_acquired = false; + int ret = 1, port = snum, l3mdev; + struct net *net = sock_net(sk); l3mdev = inet_sk_bound_l3mdev(sk); if (!port) { - head = inet_csk_find_open_port(sk, &tb, &port); + head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port); if (!head) return ret; + + head2_lock_acquired = true; + + if (tb && tb2) + goto success; + found_port = true; + } else { + head = &hinfo->bhash[inet_bhashfn(net, port, + hinfo->bhash_size)]; + spin_lock_bh(&head->lock); + inet_bind_bucket_for_each(tb, &head->chain) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) + break; + } + + if (!tb) { + tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net, + head, port, l3mdev); if (!tb) - goto tb_not_found; - goto success; + goto fail_unlock; + bhash_created = true; } - head = &hinfo->bhash[inet_bhashfn(net, port, - hinfo->bhash_size)]; - spin_lock_bh(&head->lock); - inet_bind_bucket_for_each(tb, &head->chain) - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) - goto tb_found; -tb_not_found: - tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, - net, head, port, l3mdev); - if (!tb) - goto 
fail_unlock; -tb_found: - if (!hlist_empty(&tb->owners)) { - if (sk->sk_reuse == SK_FORCE_REUSE) - goto success; - if ((tb->fastreuse > 0 && reuse) || - sk_reuseport_match(tb, sk)) - goto success; - if (inet_csk_bind_conflict(sk, tb, true, true)) + if (!found_port) { + if (!hlist_empty(&tb->owners)) { + if (sk->sk_reuse == SK_FORCE_REUSE || + (tb->fastreuse > 0 && reuse) || + sk_reuseport_match(tb, sk)) + check_bind_conflict = false; + } + + if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true)) + goto fail_unlock; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + head2_lock_acquired = true; + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + } + + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, + net, head2, port, l3mdev, sk); + if (!tb2) goto fail_unlock; + bhash2_created = true; } + + if (!found_port && check_bind_conflict) { + if (inet_csk_bind_conflict(sk, tb, tb2, true, true)) + goto fail_unlock; + } + success: inet_csk_update_fastreuse(tb, sk); if (!inet_csk(sk)->icsk_bind_hash) - inet_bind_hash(sk, tb, port); + inet_bind_hash(sk, tb, tb2, port); WARN_ON(inet_csk(sk)->icsk_bind_hash != tb); + WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2); ret = 0; fail_unlock: + if (ret) { + if (bhash_created) + inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + if (bhash2_created) + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, + tb2); + } + if (head2_lock_acquired) + spin_unlock(&head2->lock); spin_unlock_bh(&head->lock); return ret; } @@ -763,14 +906,15 @@ static void reqsk_migrate_reset(struct request_sock *req) /* return true if req was found in the ehash table */ static bool reqsk_queue_unlink(struct request_sock *req) { - struct inet_hashinfo *hashinfo = req_to_sk(req)->sk_prot->h.hashinfo; + struct sock *sk = req_to_sk(req); bool found = false; - if (sk_hashed(req_to_sk(req))) { + if (sk_hashed(sk)) { + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); spinlock_t *lock = inet_ehash_lockp(hashinfo, req->rsk_hash); spin_lock(lock); - found = __sk_nulls_del_node_init_rcu(req_to_sk(req)); + found = __sk_nulls_del_node_init_rcu(sk); spin_unlock(lock); } if (timer_pending(&req->rsk_timer) && del_timer_sync(&req->rsk_timer)) @@ -962,6 +1106,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk, inet_sk_set_state(newsk, TCP_SYN_RECV); newicsk->icsk_bind_hash = NULL; + newicsk->icsk_bind2_hash = NULL; inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port; inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num; diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index b9d995b5ce24..a0ad34e4f044 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -92,12 +92,79 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket } } +bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, + unsigned short port, int l3mdev) +{ + return net_eq(ib_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev; +} + +static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb, + struct net *net, + struct inet_bind_hashbucket *head, + unsigned short port, int l3mdev, + const struct sock *sk) +{ + write_pnet(&tb->ib_net, net); + tb->l3mdev = l3mdev; + tb->port = port; +#if IS_ENABLED(CONFIG_IPV6) + tb->family = sk->sk_family; + if (sk->sk_family == AF_INET6) + tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr; + else +#endif + tb->rcv_saddr = sk->sk_rcv_saddr; 
+ INIT_HLIST_HEAD(&tb->owners); + hlist_add_head(&tb->node, &head->chain); +} + +struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, + struct net *net, + struct inet_bind_hashbucket *head, + unsigned short port, + int l3mdev, + const struct sock *sk) +{ + struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); + + if (tb) + inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk); + + return tb; +} + +/* Caller must hold hashbucket lock for this tb with local BH disabled */ +void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) +{ + if (hlist_empty(&tb->owners)) { + __hlist_del(&tb->node); + kmem_cache_free(cachep, tb); + } +} + +static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2, + const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family != tb2->family) + return false; + + if (sk->sk_family == AF_INET6) + return ipv6_addr_equal(&tb2->v6_rcv_saddr, + &sk->sk_v6_rcv_saddr); +#endif + return tb2->rcv_saddr == sk->sk_rcv_saddr; +} + void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, - const unsigned short snum) + struct inet_bind2_bucket *tb2, unsigned short port) { - inet_sk(sk)->inet_num = snum; + inet_sk(sk)->inet_num = port; sk_add_bind_node(sk, &tb->owners); inet_csk(sk)->icsk_bind_hash = tb; + sk_add_bind2_node(sk, &tb2->owners); + inet_csk(sk)->icsk_bind2_hash = tb2; } /* @@ -105,11 +172,15 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, */ static void __inet_put_port(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num, - hashinfo->bhash_size); - struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash]; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; + struct net *net = sock_net(sk); struct inet_bind_bucket *tb; + int bhash; + + bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size); + head = &hashinfo->bhash[bhash]; + head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num); spin_lock(&head->lock); tb = inet_csk(sk)->icsk_bind_hash; @@ -117,6 +188,17 @@ static void __inet_put_port(struct sock *sk) inet_csk(sk)->icsk_bind_hash = NULL; inet_sk(sk)->inet_num = 0; inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + + spin_lock(&head2->lock); + if (inet_csk(sk)->icsk_bind2_hash) { + struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash; + + __sk_del_bind2_node(sk); + inet_csk(sk)->icsk_bind2_hash = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + } + spin_unlock(&head2->lock); + spin_unlock(&head->lock); } @@ -130,17 +212,26 @@ EXPORT_SYMBOL(inet_put_port); int __inet_inherit_port(const struct sock *sk, struct sock *child) { - struct inet_hashinfo *table = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk); unsigned short port = inet_sk(child)->inet_num; - const int bhash = inet_bhashfn(sock_net(sk), port, - table->bhash_size); - struct inet_bind_hashbucket *head = &table->bhash[bhash]; + struct inet_bind_hashbucket *head, *head2; + bool created_inet_bind_bucket = false; + struct net *net = sock_net(sk); + bool update_fastreuse = false; + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; - int l3mdev; + int bhash, l3mdev; + + bhash = inet_bhashfn(net, port, table->bhash_size); + head = &table->bhash[bhash]; + head2 = inet_bhashfn_portaddr(table, child, net, port); 
spin_lock(&head->lock); + spin_lock(&head2->lock); tb = inet_csk(sk)->icsk_bind_hash; - if (unlikely(!tb)) { + tb2 = inet_csk(sk)->icsk_bind2_hash; + if (unlikely(!tb || !tb2)) { + spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOENT; } @@ -153,25 +244,49 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child) * as that of the child socket. We have to look up or * create a new bind bucket for the child here. */ inet_bind_bucket_for_each(tb, &head->chain) { - if (net_eq(ib_net(tb), sock_net(sk)) && - tb->l3mdev == l3mdev && tb->port == port) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) break; } if (!tb) { tb = inet_bind_bucket_create(table->bind_bucket_cachep, - sock_net(sk), head, port, - l3mdev); + net, head, port, l3mdev); if (!tb) { + spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOMEM; } + created_inet_bind_bucket = true; + } + update_fastreuse = true; + + goto bhash2_find; + } else if (!inet_bind2_bucket_addr_match(tb2, child)) { + l3mdev = inet_sk_bound_l3mdev(sk); + +bhash2_find: + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child); + if (!tb2) { + tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep, + net, head2, port, + l3mdev, child); + if (!tb2) + goto error; } - inet_csk_update_fastreuse(tb, child); } - inet_bind_hash(child, tb, port); + if (update_fastreuse) + inet_csk_update_fastreuse(tb, child); + inet_bind_hash(child, tb, tb2, port); + spin_unlock(&head2->lock); spin_unlock(&head->lock); return 0; + +error: + if (created_inet_bind_bucket) + inet_bind_bucket_destroy(table->bind_bucket_cachep, tb); + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + return -ENOMEM; } EXPORT_SYMBOL_GPL(__inet_inherit_port); @@ -275,7 +390,7 @@ static inline struct sock *inet_lookup_run_bpf(struct net *net, struct sock *sk, *reuse_sk; bool no_reuseport; - if (hashinfo != &tcp_hashinfo) + if (hashinfo != net->ipv4.tcp_death_row.hashinfo) return NULL; /* only TCP is supported */ no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport, @@ -518,9 +633,9 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk, */ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - struct hlist_nulls_head *list; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_ehash_bucket *head; + struct hlist_nulls_head *list; spinlock_t *lock; bool ret = true; @@ -590,7 +705,7 @@ static int inet_reuseport_add_sock(struct sock *sk, int __inet_hash(struct sock *sk, struct sock *osk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_listen_hashbucket *ilb2; int err = 0; @@ -636,7 +751,7 @@ EXPORT_SYMBOL_GPL(inet_hash); void inet_unhash(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); if (sk_unhashed(sk)) return; @@ -675,6 +790,118 @@ void inet_unhash(struct sock *sk) } EXPORT_SYMBOL_GPL(inet_unhash); +static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb, + const struct net *net, unsigned short port, + int l3mdev, const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family != tb->family) + return false; + + if (sk->sk_family == AF_INET6) + return net_eq(ib2_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev && + ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr); + else +#endif + return net_eq(ib2_net(tb), 
net) && tb->port == port && + tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr; +} + +bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct in6_addr addr_any = {}; + + if (sk->sk_family != tb->family) + return false; + + if (sk->sk_family == AF_INET6) + return net_eq(ib2_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev && + ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any); + else +#endif + return net_eq(ib2_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev && tb->rcv_saddr == 0; +} + +/* The socket's bhash2 hashbucket spinlock must be held when this is called */ +struct inet_bind2_bucket * +inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) +{ + struct inet_bind2_bucket *bhash2 = NULL; + + inet_bind_bucket_for_each(bhash2, &head->chain) + if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk)) + break; + + return bhash2; +} + +struct inet_bind_hashbucket * +inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + u32 hash; +#if IS_ENABLED(CONFIG_IPV6) + struct in6_addr addr_any = {}; + + if (sk->sk_family == AF_INET6) + hash = ipv6_portaddr_hash(net, &addr_any, port); + else +#endif + hash = ipv4_portaddr_hash(net, 0, port); + + return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)]; +} + +int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_bind2_bucket *tb2, *new_tb2; + int l3mdev = inet_sk_bound_l3mdev(sk); + struct inet_bind_hashbucket *head2; + int port = inet_sk(sk)->inet_num; + struct net *net = sock_net(sk); + + /* Allocate a bind2 bucket ahead of time to avoid permanently putting + * the bhash2 table in an inconsistent state if a new tb2 bucket + * allocation fails. + */ + new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC); + if (!new_tb2) + return -ENOMEM; + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + if (prev_saddr) { + spin_lock_bh(&prev_saddr->lock); + __sk_del_bind2_node(sk); + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, + inet_csk(sk)->icsk_bind2_hash); + spin_unlock_bh(&prev_saddr->lock); + } + + spin_lock_bh(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = new_tb2; + inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk); + } + sk_add_bind2_node(sk, &tb2->owners); + inet_csk(sk)->icsk_bind2_hash = tb2; + spin_unlock_bh(&head2->lock); + + if (tb2 != new_tb2) + kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2); + + return 0; +} +EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr); + /* RFC 6056 3.3.4. 
Algorithm 4: Double-Hash Port Selection Algorithm * Note that we use 32bit integers (vs RFC 'short integers') * because 2^16 is not a multiple of num_ephemeral and this @@ -694,11 +921,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *, __u16, struct inet_timewait_sock **)) { struct inet_hashinfo *hinfo = death_row->hashinfo; + struct inet_bind_hashbucket *head, *head2; struct inet_timewait_sock *tw = NULL; - struct inet_bind_hashbucket *head; int port = inet_sk(sk)->inet_num; struct net *net = sock_net(sk); + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; + bool tb_created = false; u32 remaining, offset; int ret, i, low, high; int l3mdev; @@ -729,8 +958,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, if (likely(remaining > 1)) remaining &= ~1U; - net_get_random_once(table_perturb, - INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb)); + get_random_sleepable_once(table_perturb, + INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb)); index = port_offset & (INET_TABLE_PERTURB_SIZE - 1); offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32); @@ -755,8 +984,7 @@ other_parity_scan: * the established check is already unique enough. */ inet_bind_bucket_for_each(tb, &head->chain) { - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) { + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) goto next_port; @@ -774,6 +1002,7 @@ other_parity_scan: spin_unlock_bh(&head->lock); return -ENOMEM; } + tb_created = true; tb->fastreuse = -1; tb->fastreuseport = -1; goto ok; @@ -789,6 +1018,20 @@ next_port: return -EADDRNOTAVAIL; ok: + /* Find the corresponding tb2 bucket since we need to + * add the socket to the bhash2 table as well + */ + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net, + head2, port, l3mdev, sk); + if (!tb2) + goto error; + } + /* Here we want to add a little bit of randomness to the next source * port that will be chosen. 
We use a max() with a random here so that * on low contention the randomness is maximal and on high contention @@ -798,7 +1041,10 @@ ok: WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); /* Head lock still held and bh's disabled */ - inet_bind_hash(sk, tb, port); + inet_bind_hash(sk, tb, tb2, port); + + spin_unlock(&head2->lock); + if (sk_unhashed(sk)) { inet_sk(sk)->inet_sport = htons(port); inet_ehash_nolisten(sk, (struct sock *)tw, NULL); @@ -810,6 +1056,13 @@ ok: inet_twsk_deschedule_put(tw); local_bh_enable(); return 0; + +error: + spin_unlock(&head2->lock); + if (tb_created) + inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + spin_unlock_bh(&head->lock); + return -ENOMEM; } /* @@ -902,3 +1155,50 @@ int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo) return 0; } EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc); + +struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo, + unsigned int ehash_entries) +{ + struct inet_hashinfo *new_hashinfo; + int i; + + new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL); + if (!new_hashinfo) + goto err; + + new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket), + GFP_KERNEL_ACCOUNT); + if (!new_hashinfo->ehash) + goto free_hashinfo; + + new_hashinfo->ehash_mask = ehash_entries - 1; + + if (inet_ehash_locks_alloc(new_hashinfo)) + goto free_ehash; + + for (i = 0; i < ehash_entries; i++) + INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i); + + new_hashinfo->pernet = true; + + return new_hashinfo; + +free_ehash: + vfree(new_hashinfo->ehash); +free_hashinfo: + kfree(new_hashinfo); +err: + return NULL; +} +EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc); + +void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) +{ + if (!hashinfo->pernet) + return; + + inet_ehash_locks_free(hashinfo); + vfree(hashinfo->ehash); + kfree(hashinfo); +} +EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free); diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c index 47ccc343c9fb..71d3bb0abf6c 100644 --- a/net/ipv4/inet_timewait_sock.c +++ b/net/ipv4/inet_timewait_sock.c @@ -59,9 +59,7 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw) inet_twsk_bind_unhash(tw, hashinfo); spin_unlock(&bhead->lock); - if (refcount_dec_and_test(&tw->tw_dr->tw_refcount)) - kfree(tw->tw_dr); - + refcount_dec(&tw->tw_dr->tw_refcount); inet_twsk_put(tw); } diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c index 04e2034f2f8e..1ae83ad629b2 100644 --- a/net/ipv4/ip_output.c +++ b/net/ipv4/ip_output.c @@ -1043,7 +1043,7 @@ static int __ip_append_data(struct sock *sk, paged = true; zc = true; } else { - uarg->zerocopy = 0; + uarg_to_msgzc(uarg)->zerocopy = 0; skb_zcopy_set(skb, uarg, &extra_uref); } } @@ -1109,10 +1109,7 @@ alloc_new_skb: (fraglen + alloc_extra < SKB_MAX_ALLOC || !(rt->dst.dev->features & NETIF_F_SG))) alloclen = fraglen; - else if (!zc) { - alloclen = min_t(int, fraglen, MAX_HEADER); - pagedlen = fraglen - alloclen; - } else { + else { alloclen = fragheaderlen + transhdrlen; pagedlen = datalen - transhdrlen; } diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index e49a61a053a6..6e19cad154f5 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -888,8 +888,8 @@ static int compat_ip_mcast_join_leave(struct sock *sk, int optname, DEFINE_STATIC_KEY_FALSE(ip4_min_ttl); -static int do_ip_setsockopt(struct sock *sk, int level, int optname, - sockptr_t optval, unsigned int optlen) +int do_ip_setsockopt(struct sock *sk, int level, 
int optname, + sockptr_t optval, unsigned int optlen) { struct inet_sock *inet = inet_sk(sk); struct net *net = sock_net(sk); @@ -944,7 +944,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, err = 0; if (needs_rtnl) rtnl_lock(); - lock_sock(sk); + sockopt_lock_sock(sk); switch (optname) { case IP_OPTIONS: @@ -1333,14 +1333,14 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, case IP_IPSEC_POLICY: case IP_XFRM_POLICY: err = -EPERM; - if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) break; err = xfrm_user_policy(sk, optname, optval, optlen); break; case IP_TRANSPARENT: - if (!!val && !ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && - !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { + if (!!val && !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && + !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { err = -EPERM; break; } @@ -1368,13 +1368,13 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, err = -ENOPROTOOPT; break; } - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return err; e_inval: - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return -EINVAL; @@ -1462,37 +1462,37 @@ static bool getsockopt_needs_rtnl(int optname) return false; } -static int ip_get_mcast_msfilter(struct sock *sk, void __user *optval, - int __user *optlen, int len) +static int ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) { const int size0 = offsetof(struct group_filter, gf_slist_flex); - struct group_filter __user *p = optval; struct group_filter gsf; - int num; + int num, gsf_size; int err; if (len < size0) return -EINVAL; - if (copy_from_user(&gsf, p, size0)) + if (copy_from_sockptr(&gsf, optval, size0)) return -EFAULT; num = gsf.gf_numsrc; - err = ip_mc_gsfget(sk, &gsf, p->gf_slist_flex); + err = ip_mc_gsfget(sk, &gsf, optval, + offsetof(struct group_filter, gf_slist_flex)); if (err) return err; if (gsf.gf_numsrc < num) num = gsf.gf_numsrc; - if (put_user(GROUP_FILTER_SIZE(num), optlen) || - copy_to_user(p, &gsf, size0)) + gsf_size = GROUP_FILTER_SIZE(num); + if (copy_to_sockptr(optlen, &gsf_size, sizeof(int)) || + copy_to_sockptr(optval, &gsf, size0)) return -EFAULT; return 0; } -static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, - int __user *optlen, int len) +static int compat_ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) { const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); - struct compat_group_filter __user *p = optval; struct compat_group_filter gf32; struct group_filter gf; int num; @@ -1500,7 +1500,7 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, if (len < size0) return -EINVAL; - if (copy_from_user(&gf32, p, size0)) + if (copy_from_sockptr(&gf32, optval, size0)) return -EFAULT; gf.gf_interface = gf32.gf_interface; @@ -1508,21 +1508,24 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, num = gf.gf_numsrc = gf32.gf_numsrc; gf.gf_group = gf32.gf_group; - err = ip_mc_gsfget(sk, &gf, p->gf_slist_flex); + err = ip_mc_gsfget(sk, &gf, optval, + offsetof(struct compat_group_filter, gf_slist_flex)); if (err) return err; if (gf.gf_numsrc < num) num = gf.gf_numsrc; len = GROUP_FILTER_SIZE(num) - (sizeof(gf) - sizeof(gf32)); - if (put_user(len, optlen) || - put_user(gf.gf_fmode, &p->gf_fmode) || - 
put_user(gf.gf_numsrc, &p->gf_numsrc)) + if (copy_to_sockptr(optlen, &len, sizeof(int)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode), + &gf.gf_fmode, sizeof(gf.gf_fmode)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc), + &gf.gf_numsrc, sizeof(gf.gf_numsrc))) return -EFAULT; return 0; } -static int do_ip_getsockopt(struct sock *sk, int level, int optname, - char __user *optval, int __user *optlen) +int do_ip_getsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, sockptr_t optlen) { struct inet_sock *inet = inet_sk(sk); bool needs_rtnl = getsockopt_needs_rtnl(optname); @@ -1535,14 +1538,14 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, if (ip_mroute_opt(optname)) return ip_mroute_getsockopt(sk, optname, optval, optlen); - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len < 0) return -EINVAL; if (needs_rtnl) rtnl_lock(); - lock_sock(sk); + sockopt_lock_sock(sk); switch (optname) { case IP_OPTIONS: @@ -1558,17 +1561,19 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, memcpy(optbuf, &inet_opt->opt, sizeof(struct ip_options) + inet_opt->opt.optlen); - release_sock(sk); + sockopt_release_sock(sk); - if (opt->optlen == 0) - return put_user(0, optlen); + if (opt->optlen == 0) { + len = 0; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } ip_options_undo(opt); len = min_t(unsigned int, len, opt->optlen); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, opt->__data, len)) + if (copy_to_sockptr(optval, opt->__data, len)) return -EFAULT; return 0; } @@ -1632,7 +1637,7 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, dst_release(dst); } if (!val) { - release_sock(sk); + sockopt_release_sock(sk); return -ENOTCONN; } break; @@ -1657,11 +1662,11 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, struct in_addr addr; len = min_t(unsigned int, len, sizeof(struct in_addr)); addr.s_addr = inet->mc_addr; - release_sock(sk); + sockopt_release_sock(sk); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &addr, len)) + if (copy_to_sockptr(optval, &addr, len)) return -EFAULT; return 0; } @@ -1673,12 +1678,11 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, err = -EINVAL; goto out; } - if (copy_from_user(&msf, optval, IP_MSFILTER_SIZE(0))) { + if (copy_from_sockptr(&msf, optval, IP_MSFILTER_SIZE(0))) { err = -EFAULT; goto out; } - err = ip_mc_msfget(sk, &msf, - (struct ip_msfilter __user *)optval, optlen); + err = ip_mc_msfget(sk, &msf, optval, optlen); goto out; } case MCAST_MSFILTER: @@ -1695,13 +1699,18 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, { struct msghdr msg; - release_sock(sk); + sockopt_release_sock(sk); if (sk->sk_type != SOCK_STREAM) return -ENOPROTOOPT; - msg.msg_control_is_user = true; - msg.msg_control_user = optval; + if (optval.is_kernel) { + msg.msg_control_is_user = false; + msg.msg_control = optval.kernel; + } else { + msg.msg_control_is_user = true; + msg.msg_control_user = optval.user; + } msg.msg_controllen = len; msg.msg_flags = in_compat_syscall() ? 
MSG_CMSG_COMPAT : 0; @@ -1722,7 +1731,7 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos); } len -= msg.msg_controllen; - return put_user(len, optlen); + return copy_to_sockptr(optlen, &len, sizeof(int)); } case IP_FREEBIND: val = inet->freebind; @@ -1734,29 +1743,29 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, val = inet->min_ttl; break; default: - release_sock(sk); + sockopt_release_sock(sk); return -ENOPROTOOPT; } - release_sock(sk); + sockopt_release_sock(sk); if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) { unsigned char ucval = (unsigned char)val; len = 1; - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &ucval, 1)) + if (copy_to_sockptr(optval, &ucval, 1)) return -EFAULT; } else { len = min_t(unsigned int, sizeof(int), len); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, len)) + if (copy_to_sockptr(optval, &val, len)) return -EFAULT; } return 0; out: - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return err; @@ -1767,7 +1776,8 @@ int ip_getsockopt(struct sock *sk, int level, { int err; - err = do_ip_getsockopt(sk, level, optname, optval, optlen); + err = do_ip_getsockopt(sk, level, optname, + USER_SOCKPTR(optval), USER_SOCKPTR(optlen)); #if IS_ENABLED(CONFIG_BPFILTER_UMH) if (optname >= BPFILTER_IPT_SO_GET_INFO && diff --git a/net/ipv4/ip_tunnel_core.c b/net/ipv4/ip_tunnel_core.c index cc1caab4a654..92c02c886fe7 100644 --- a/net/ipv4/ip_tunnel_core.c +++ b/net/ipv4/ip_tunnel_core.c @@ -1079,3 +1079,70 @@ EXPORT_SYMBOL(ip_tunnel_parse_protocol); const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tunnel_parse_protocol }; EXPORT_SYMBOL(ip_tunnel_header_ops); + +/* This function returns true when ENCAP attributes are present in the nl msg */ +bool ip_tunnel_netlink_encap_parms(struct nlattr *data[], + struct ip_tunnel_encap *encap) +{ + bool ret = false; + + memset(encap, 0, sizeof(*encap)); + + if (!data) + return ret; + + if (data[IFLA_IPTUN_ENCAP_TYPE]) { + ret = true; + encap->type = nla_get_u16(data[IFLA_IPTUN_ENCAP_TYPE]); + } + + if (data[IFLA_IPTUN_ENCAP_FLAGS]) { + ret = true; + encap->flags = nla_get_u16(data[IFLA_IPTUN_ENCAP_FLAGS]); + } + + if (data[IFLA_IPTUN_ENCAP_SPORT]) { + ret = true; + encap->sport = nla_get_be16(data[IFLA_IPTUN_ENCAP_SPORT]); + } + + if (data[IFLA_IPTUN_ENCAP_DPORT]) { + ret = true; + encap->dport = nla_get_be16(data[IFLA_IPTUN_ENCAP_DPORT]); + } + + return ret; +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_encap_parms); + +void ip_tunnel_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm *parms) +{ + if (data[IFLA_IPTUN_LINK]) + parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); + + if (data[IFLA_IPTUN_LOCAL]) + parms->iph.saddr = nla_get_be32(data[IFLA_IPTUN_LOCAL]); + + if (data[IFLA_IPTUN_REMOTE]) + parms->iph.daddr = nla_get_be32(data[IFLA_IPTUN_REMOTE]); + + if (data[IFLA_IPTUN_TTL]) { + parms->iph.ttl = nla_get_u8(data[IFLA_IPTUN_TTL]); + if (parms->iph.ttl) + parms->iph.frag_off = htons(IP_DF); + } + + if (data[IFLA_IPTUN_TOS]) + parms->iph.tos = nla_get_u8(data[IFLA_IPTUN_TOS]); + + if (!data[IFLA_IPTUN_PMTUDISC] || nla_get_u8(data[IFLA_IPTUN_PMTUDISC])) + parms->iph.frag_off = htons(IP_DF); + + if (data[IFLA_IPTUN_FLAGS]) + parms->i_flags = nla_get_be16(data[IFLA_IPTUN_FLAGS]); + + if (data[IFLA_IPTUN_PROTO]) + 
parms->iph.protocol = nla_get_u8(data[IFLA_IPTUN_PROTO]); +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_parms); diff --git a/net/ipv4/ipcomp.c b/net/ipv4/ipcomp.c index 366094c1ce6c..5a4fb2539b08 100644 --- a/net/ipv4/ipcomp.c +++ b/net/ipv4/ipcomp.c @@ -117,7 +117,8 @@ out: return err; } -static int ipcomp4_init_state(struct xfrm_state *x) +static int ipcomp4_init_state(struct xfrm_state *x, + struct netlink_ext_ack *extack) { int err = -EINVAL; @@ -129,17 +130,20 @@ static int ipcomp4_init_state(struct xfrm_state *x) x->props.header_len += sizeof(struct iphdr); break; default: + NL_SET_ERR_MSG(extack, "Unsupported XFRM mode for IPcomp"); goto out; } - err = ipcomp_init_state(x); + err = ipcomp_init_state(x, extack); if (err) goto out; if (x->props.mode == XFRM_MODE_TUNNEL) { err = ipcomp_tunnel_attach(x); - if (err) + if (err) { + NL_SET_ERR_MSG(extack, "Kernel error: failed to initialize the associated state"); goto out; + } } err = 0; diff --git a/net/ipv4/ipip.c b/net/ipv4/ipip.c index 123ea63a04cb..180f9daf5bec 100644 --- a/net/ipv4/ipip.c +++ b/net/ipv4/ipip.c @@ -417,29 +417,7 @@ static void ipip_netlink_parms(struct nlattr *data[], if (!data) return; - if (data[IFLA_IPTUN_LINK]) - parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); - - if (data[IFLA_IPTUN_LOCAL]) - parms->iph.saddr = nla_get_in_addr(data[IFLA_IPTUN_LOCAL]); - - if (data[IFLA_IPTUN_REMOTE]) - parms->iph.daddr = nla_get_in_addr(data[IFLA_IPTUN_REMOTE]); - - if (data[IFLA_IPTUN_TTL]) { - parms->iph.ttl = nla_get_u8(data[IFLA_IPTUN_TTL]); - if (parms->iph.ttl) - parms->iph.frag_off = htons(IP_DF); - } - - if (data[IFLA_IPTUN_TOS]) - parms->iph.tos = nla_get_u8(data[IFLA_IPTUN_TOS]); - - if (data[IFLA_IPTUN_PROTO]) - parms->iph.protocol = nla_get_u8(data[IFLA_IPTUN_PROTO]); - - if (!data[IFLA_IPTUN_PMTUDISC] || nla_get_u8(data[IFLA_IPTUN_PMTUDISC])) - parms->iph.frag_off = htons(IP_DF); + ip_tunnel_netlink_parms(data, parms); if (data[IFLA_IPTUN_COLLECT_METADATA]) *collect_md = true; @@ -448,40 +426,6 @@ static void ipip_netlink_parms(struct nlattr *data[], *fwmark = nla_get_u32(data[IFLA_IPTUN_FWMARK]); } -/* This function returns true when ENCAP attributes are present in the nl msg */ -static bool ipip_netlink_encap_parms(struct nlattr *data[], - struct ip_tunnel_encap *ipencap) -{ - bool ret = false; - - memset(ipencap, 0, sizeof(*ipencap)); - - if (!data) - return ret; - - if (data[IFLA_IPTUN_ENCAP_TYPE]) { - ret = true; - ipencap->type = nla_get_u16(data[IFLA_IPTUN_ENCAP_TYPE]); - } - - if (data[IFLA_IPTUN_ENCAP_FLAGS]) { - ret = true; - ipencap->flags = nla_get_u16(data[IFLA_IPTUN_ENCAP_FLAGS]); - } - - if (data[IFLA_IPTUN_ENCAP_SPORT]) { - ret = true; - ipencap->sport = nla_get_be16(data[IFLA_IPTUN_ENCAP_SPORT]); - } - - if (data[IFLA_IPTUN_ENCAP_DPORT]) { - ret = true; - ipencap->dport = nla_get_be16(data[IFLA_IPTUN_ENCAP_DPORT]); - } - - return ret; -} - static int ipip_newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[], struct netlink_ext_ack *extack) @@ -491,7 +435,7 @@ static int ipip_newlink(struct net *src_net, struct net_device *dev, struct ip_tunnel_encap ipencap; __u32 fwmark = 0; - if (ipip_netlink_encap_parms(data, &ipencap)) { + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { int err = ip_tunnel_encap_setup(t, &ipencap); if (err < 0) @@ -512,7 +456,7 @@ static int ipip_changelink(struct net_device *dev, struct nlattr *tb[], bool collect_md; __u32 fwmark = t->fwmark; - if (ipip_netlink_encap_parms(data, &ipencap)) { + if (ip_tunnel_netlink_encap_parms(data, 
&ipencap)) { int err = ip_tunnel_encap_setup(t, &ipencap); if (err < 0) diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index e11d6b0b62b7..e04544ac4b45 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -1548,7 +1548,8 @@ out: } /* Getsock opt support for the multicast routing system. */ -int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen) +int ip_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval, + sockptr_t optlen) { int olr; int val; @@ -1579,14 +1580,14 @@ int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int return -ENOPROTOOPT; } - if (get_user(olr, optlen)) + if (copy_from_sockptr(&olr, optlen, sizeof(int))) return -EFAULT; olr = min_t(unsigned int, olr, sizeof(int)); if (olr < 0) return -EINVAL; - if (put_user(olr, optlen)) + if (copy_to_sockptr(optlen, &olr, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, olr)) + if (copy_to_sockptr(optval, &val, olr)) return -EFAULT; return 0; } diff --git a/net/ipv4/netfilter/ipt_rpfilter.c b/net/ipv4/netfilter/ipt_rpfilter.c index 8cd3224d913e..8183bbcabb4a 100644 --- a/net/ipv4/netfilter/ipt_rpfilter.c +++ b/net/ipv4/netfilter/ipt_rpfilter.c @@ -33,7 +33,6 @@ static bool rpfilter_lookup_reverse(struct net *net, struct flowi4 *fl4, const struct net_device *dev, u8 flags) { struct fib_result res; - int ret __maybe_unused; if (fib_lookup(net, fl4, &res, FIB_LOOKUP_IGNORE_LINKSTATE)) return false; diff --git a/net/ipv4/netfilter/nf_nat_h323.c b/net/ipv4/netfilter/nf_nat_h323.c index a334f0dcc2d0..faee20af4856 100644 --- a/net/ipv4/netfilter/nf_nat_h323.c +++ b/net/ipv4/netfilter/nf_nat_h323.c @@ -291,20 +291,7 @@ static int nat_t120(struct sk_buff *skb, struct nf_conn *ct, exp->expectfn = nf_nat_follow_master; exp->dir = !dir; - /* Try to get same port: if not, try to change it. */ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_h323: out of TCP ports\n"); return 0; @@ -347,20 +334,7 @@ static int nat_h245(struct sk_buff *skb, struct nf_conn *ct, if (info->sig_port[dir] == port) nated_port = ntohs(info->sig_port[!dir]); - /* Try to get same port: if not, try to change it. */ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); return 0; @@ -439,20 +413,7 @@ static int nat_q931(struct sk_buff *skb, struct nf_conn *ct, if (info->sig_port[dir] == port) nated_port = ntohs(info->sig_port[!dir]); - /* Try to get same port: if not, try to change it. 
*/ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_ras: out of TCP ports\n"); return 0; @@ -532,20 +493,7 @@ static int nat_callforwarding(struct sk_buff *skb, struct nf_conn *ct, exp->expectfn = ip_nat_callforwarding_expect; exp->dir = !dir; - /* Try to get same port: if not, try to change it. */ - for (nated_port = ntohs(port); nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, ntohs(port)); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); return 0; diff --git a/net/ipv4/netfilter/nf_socket_ipv4.c b/net/ipv4/netfilter/nf_socket_ipv4.c index 2d42e4c35a20..a1350fc25838 100644 --- a/net/ipv4/netfilter/nf_socket_ipv4.c +++ b/net/ipv4/netfilter/nf_socket_ipv4.c @@ -71,8 +71,8 @@ nf_socket_get_sock_v4(struct net *net, struct sk_buff *skb, const int doff, { switch (protocol) { case IPPROTO_TCP: - return inet_lookup(net, &tcp_hashinfo, skb, doff, - saddr, sport, daddr, dport, + return inet_lookup(net, net->ipv4.tcp_death_row.hashinfo, + skb, doff, saddr, sport, daddr, dport, in->ifindex); case IPPROTO_UDP: return udp4_lib_lookup(net, saddr, sport, daddr, dport, diff --git a/net/ipv4/netfilter/nf_tproxy_ipv4.c b/net/ipv4/netfilter/nf_tproxy_ipv4.c index b2bae0b0e42a..b22b2c745c76 100644 --- a/net/ipv4/netfilter/nf_tproxy_ipv4.c +++ b/net/ipv4/netfilter/nf_tproxy_ipv4.c @@ -79,6 +79,7 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, const struct net_device *in, const enum nf_tproxy_lookup_t lookup_type) { + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; struct sock *sk; switch (protocol) { @@ -92,12 +93,10 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, switch (lookup_type) { case NF_TPROXY_LOOKUP_LISTENER: - sk = inet_lookup_listener(net, &tcp_hashinfo, skb, - ip_hdrlen(skb) + - __tcp_hdrlen(hp), - saddr, sport, - daddr, dport, - in->ifindex, 0); + sk = inet_lookup_listener(net, hinfo, skb, + ip_hdrlen(skb) + __tcp_hdrlen(hp), + saddr, sport, daddr, dport, + in->ifindex, 0); if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) sk = NULL; @@ -108,9 +107,8 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, */ break; case NF_TPROXY_LOOKUP_ESTABLISHED: - sk = inet_lookup_established(net, &tcp_hashinfo, - saddr, sport, daddr, dport, - in->ifindex); + sk = inet_lookup_established(net, hinfo, saddr, sport, + daddr, dport, in->ifindex); break; default: BUG(); diff --git a/net/ipv4/netfilter/nft_fib_ipv4.c b/net/ipv4/netfilter/nft_fib_ipv4.c index b75cac69bd7e..7ade04ff972d 100644 --- a/net/ipv4/netfilter/nft_fib_ipv4.c +++ b/net/ipv4/netfilter/nft_fib_ipv4.c @@ -83,6 +83,9 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, else oif = NULL; + if (priv->flags & NFTA_FIB_F_IIF) + fl4.flowi4_oif = l3mdev_master_ifindex_rcu(oif); + if (nft_hook(pkt) == NF_INET_PRE_ROUTING && nft_fib_is_loopback(pkt->skb, nft_in(pkt))) { nft_fib_store_result(dest, priv, nft_in(pkt)); diff --git a/net/ipv4/ping.c b/net/ipv4/ping.c index b83c2bd9d722..517042caf6dc 100644 --- 
a/net/ipv4/ping.c +++ b/net/ipv4/ping.c @@ -33,6 +33,7 @@ #include <linux/skbuff.h> #include <linux/proc_fs.h> #include <linux/export.h> +#include <linux/bpf-cgroup.h> #include <net/sock.h> #include <net/ping.h> #include <net/udp.h> @@ -295,6 +296,19 @@ void ping_close(struct sock *sk, long timeout) } EXPORT_SYMBOL_GPL(ping_close); +static int ping_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from __ip4_datagram_connect() and + * intended to prevent BPF program called below from accessing bytes + * that are out of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr); +} + /* Checks the bind address and possibly modifies sk->sk_bound_dev_if. */ static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, struct sockaddr *uaddr, int addr_len) @@ -1009,6 +1023,7 @@ struct proto ping_prot = { .owner = THIS_MODULE, .init = ping_init_sock, .close = ping_close, + .pre_connect = ping_pre_connect, .connect = ip4_datagram_connect, .disconnect = __udp_disconnect, .setsockopt = ip_setsockopt, diff --git a/net/ipv4/proc.c b/net/ipv4/proc.c index 0088a4c64d77..5386f460bd20 100644 --- a/net/ipv4/proc.c +++ b/net/ipv4/proc.c @@ -59,7 +59,7 @@ static int sockstat_seq_show(struct seq_file *seq, void *v) socket_seq_show(seq); seq_printf(seq, "TCP: inuse %d orphan %d tw %d alloc %d mem %ld\n", sock_prot_inuse_get(net, &tcp_prot), orphans, - refcount_read(&net->ipv4.tcp_death_row->tw_refcount) - 1, + refcount_read(&net->ipv4.tcp_death_row.tw_refcount) - 1, sockets, proto_memory_allocated(&tcp_prot)); seq_printf(seq, "UDP: inuse %d mem %ld\n", sock_prot_inuse_get(net, &udp_prot), diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 5490c285668b..9b8a6db7a66b 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -39,6 +39,7 @@ static u32 u32_max_div_HZ = UINT_MAX / HZ; static int one_day_secs = 24 * 3600; static u32 fib_multipath_hash_fields_all_mask __maybe_unused = FIB_MULTIPATH_HASH_FIELD_ALL_MASK; +static unsigned int tcp_child_ehash_entries_max = 16 * 1024 * 1024; /* obsolete */ static int sysctl_tcp_low_latency __read_mostly; @@ -382,6 +383,29 @@ static int proc_tcp_available_ulp(struct ctl_table *ctl, return ret; } +static int proc_tcp_ehash_entries(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_tcp_child_ehash_entries); + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + int tcp_ehash_entries; + struct ctl_table tbl; + + tcp_ehash_entries = hinfo->ehash_mask + 1; + + /* A negative number indicates that the child netns + * shares the global ehash. + */ + if (!net_eq(net, &init_net) && !hinfo->pernet) + tcp_ehash_entries *= -1; + + tbl.data = &tcp_ehash_entries; + tbl.maxlen = sizeof(int); + + return proc_dointvec(&tbl, write, buffer, lenp, ppos); +} + #ifdef CONFIG_IP_ROUTE_MULTIPATH static int proc_fib_multipath_hash_policy(struct ctl_table *table, int write, void *buffer, size_t *lenp, @@ -530,10 +554,9 @@ static struct ctl_table ipv4_table[] = { }; static struct ctl_table ipv4_net_table[] = { - /* tcp_max_tw_buckets must be first in this table. 
*/ { .procname = "tcp_max_tw_buckets", -/* .data = &init_net.ipv4.tcp_death_row.sysctl_max_tw_buckets, */ + .data = &init_net.ipv4.tcp_death_row.sysctl_max_tw_buckets, .maxlen = sizeof(int), .mode = 0644, .proc_handler = proc_dointvec @@ -1322,6 +1345,21 @@ static struct ctl_table ipv4_net_table[] = { .extra2 = SYSCTL_ONE, }, { + .procname = "tcp_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .mode = 0444, + .proc_handler = proc_tcp_ehash_entries, + }, + { + .procname = "tcp_child_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_child_ehash_entries_max, + }, + { .procname = "udp_rmem_min", .data = &init_net.ipv4.sysctl_udp_rmem_min, .maxlen = sizeof(init_net.ipv4.sysctl_udp_rmem_min), @@ -1361,8 +1399,7 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) if (!table) goto err_alloc; - /* skip first entry (sysctl_max_tw_buckets) */ - for (i = 1; i < ARRAY_SIZE(ipv4_net_table) - 1; i++) { + for (i = 0; i < ARRAY_SIZE(ipv4_net_table) - 1; i++) { if (table[i].data) { /* Update the variables to point into * the current struct net @@ -1377,8 +1414,6 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) } } - table[0].data = &net->ipv4.tcp_death_row->sysctl_max_tw_buckets; - net->ipv4.ipv4_hdr = register_net_sysctl(net, "net/ipv4", table); if (!net->ipv4.ipv4_hdr) goto err_reg; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index e373dde1f46f..0c51abeee172 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -1162,9 +1162,8 @@ void tcp_free_fastopen_req(struct tcp_sock *tp) } } -static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, - int *copied, size_t size, - struct ubuf_info *uarg) +int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, int *copied, + size_t size, struct ubuf_info *uarg) { struct tcp_sock *tp = tcp_sk(sk); struct inet_sock *inet = inet_sk(sk); @@ -1239,7 +1238,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) } zc = sk->sk_route_caps & NETIF_F_SG; if (!zc) - uarg->zerocopy = 0; + uarg_to_msgzc(uarg)->zerocopy = 0; } } @@ -3137,6 +3136,8 @@ int tcp_disconnect(struct sock *sk, int flags) tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; tcp_snd_cwnd_set(tp, TCP_INIT_CWND); tp->snd_cwnd_cnt = 0; + tp->is_cwnd_limited = 0; + tp->max_packets_out = 0; tp->window_clamp = 0; tp->delivered = 0; tp->delivered_ce = 0; @@ -3208,7 +3209,7 @@ EXPORT_SYMBOL(tcp_disconnect); static inline bool tcp_can_repair_sock(const struct sock *sk) { - return ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) && + return sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) && (sk->sk_state != TCP_LISTEN); } @@ -3485,8 +3486,8 @@ int tcp_set_window_clamp(struct sock *sk, int val) /* * Socket option code for TCP. 
*/ -static int do_tcp_setsockopt(struct sock *sk, int level, int optname, - sockptr_t optval, unsigned int optlen) +int do_tcp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) { struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); @@ -3508,11 +3509,11 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, return -EFAULT; name[val] = 0; - lock_sock(sk); - err = tcp_set_congestion_control(sk, name, true, - ns_capable(sock_net(sk)->user_ns, - CAP_NET_ADMIN)); - release_sock(sk); + sockopt_lock_sock(sk); + err = tcp_set_congestion_control(sk, name, !has_current_bpf_ctx(), + sockopt_ns_capable(sock_net(sk)->user_ns, + CAP_NET_ADMIN)); + sockopt_release_sock(sk); return err; } case TCP_ULP: { @@ -3528,9 +3529,9 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, return -EFAULT; name[val] = 0; - lock_sock(sk); + sockopt_lock_sock(sk); err = tcp_set_ulp(sk, name); - release_sock(sk); + sockopt_release_sock(sk); return err; } case TCP_FASTOPEN_KEY: { @@ -3563,7 +3564,7 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, if (copy_from_sockptr(&val, optval, sizeof(val))) return -EFAULT; - lock_sock(sk); + sockopt_lock_sock(sk); switch (optname) { case TCP_MAXSEG: @@ -3785,7 +3786,7 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, break; } - release_sock(sk); + sockopt_release_sock(sk); return err; } @@ -4049,15 +4050,15 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk, return stats; } -static int do_tcp_getsockopt(struct sock *sk, int level, - int optname, char __user *optval, int __user *optlen) +int do_tcp_getsockopt(struct sock *sk, int level, + int optname, sockptr_t optval, sockptr_t optlen) { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); int val, len; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, sizeof(int)); @@ -4107,15 +4108,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level, case TCP_INFO: { struct tcp_info info; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; tcp_get_info(sk, &info); len = min_t(unsigned int, len, sizeof(info)); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &info, len)) + if (copy_to_sockptr(optval, &info, len)) return -EFAULT; return 0; } @@ -4125,7 +4126,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, size_t sz = 0; int attr; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; ca_ops = icsk->icsk_ca_ops; @@ -4133,9 +4134,9 @@ static int do_tcp_getsockopt(struct sock *sk, int level, sz = ca_ops->get_info(sk, ~0U, &attr, &info); len = min_t(unsigned int, len, sz); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &info, len)) + if (copy_to_sockptr(optval, &info, len)) return -EFAULT; return 0; } @@ -4144,27 +4145,28 @@ static int do_tcp_getsockopt(struct sock *sk, int level, break; case TCP_CONGESTION: - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, TCP_CA_NAME_MAX); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, 
icsk->icsk_ca_ops->name, len)) + if (copy_to_sockptr(optval, icsk->icsk_ca_ops->name, len)) return -EFAULT; return 0; case TCP_ULP: - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, TCP_ULP_NAME_MAX); if (!icsk->icsk_ulp_ops) { - if (put_user(0, optlen)) + len = 0; + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; return 0; } - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, icsk->icsk_ulp_ops->name, len)) + if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len)) return -EFAULT; return 0; @@ -4172,15 +4174,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level, u64 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u64)]; unsigned int key_len; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; key_len = tcp_fastopen_get_cipher(net, icsk, key) * TCP_FASTOPEN_KEY_LENGTH; len = min_t(unsigned int, len, key_len); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, key, len)) + if (copy_to_sockptr(optval, key, len)) return -EFAULT; return 0; } @@ -4206,7 +4208,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, case TCP_REPAIR_WINDOW: { struct tcp_repair_window opt; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len != sizeof(opt)) @@ -4221,7 +4223,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, opt.rcv_wnd = tp->rcv_wnd; opt.rcv_wup = tp->rcv_wup; - if (copy_to_user(optval, &opt, len)) + if (copy_to_sockptr(optval, &opt, len)) return -EFAULT; return 0; } @@ -4267,35 +4269,35 @@ static int do_tcp_getsockopt(struct sock *sk, int level, val = tp->save_syn; break; case TCP_SAVED_SYN: { - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; - lock_sock(sk); + sockopt_lock_sock(sk); if (tp->saved_syn) { if (len < tcp_saved_syn_len(tp->saved_syn)) { - if (put_user(tcp_saved_syn_len(tp->saved_syn), - optlen)) { - release_sock(sk); + len = tcp_saved_syn_len(tp->saved_syn); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); return -EFAULT; } - release_sock(sk); + sockopt_release_sock(sk); return -EINVAL; } len = tcp_saved_syn_len(tp->saved_syn); - if (put_user(len, optlen)) { - release_sock(sk); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); return -EFAULT; } - if (copy_to_user(optval, tp->saved_syn->data, len)) { - release_sock(sk); + if (copy_to_sockptr(optval, tp->saved_syn->data, len)) { + sockopt_release_sock(sk); return -EFAULT; } tcp_saved_syn_free(tp); - release_sock(sk); + sockopt_release_sock(sk); } else { - release_sock(sk); + sockopt_release_sock(sk); len = 0; - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; } return 0; @@ -4306,31 +4308,31 @@ static int do_tcp_getsockopt(struct sock *sk, int level, struct tcp_zerocopy_receive zc = {}; int err; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len < 0 || len < offsetofend(struct tcp_zerocopy_receive, length)) return -EINVAL; if (unlikely(len > sizeof(zc))) { - err = check_zeroed_user(optval + sizeof(zc), - len - sizeof(zc)); + err = check_zeroed_sockptr(optval, sizeof(zc), + len - sizeof(zc)); if (err < 1) return err == 0 ? 
-EINVAL : err; len = sizeof(zc); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; } - if (copy_from_user(&zc, optval, len)) + if (copy_from_sockptr(&zc, optval, len)) return -EFAULT; if (zc.reserved) return -EINVAL; if (zc.msg_flags & ~(TCP_VALID_ZC_MSG_FLAGS)) return -EINVAL; - lock_sock(sk); + sockopt_lock_sock(sk); err = tcp_zerocopy_receive(sk, &zc, &tss); err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname, &zc, &len, err); - release_sock(sk); + sockopt_release_sock(sk); if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags)) goto zerocopy_rcv_cmsg; switch (len) { @@ -4360,7 +4362,7 @@ zerocopy_rcv_sk_err: zerocopy_rcv_inq: zc.inq = tcp_inq_hint(sk); zerocopy_rcv_out: - if (!err && copy_to_user(optval, &zc, len)) + if (!err && copy_to_sockptr(optval, &zc, len)) err = -EFAULT; return err; } @@ -4369,9 +4371,9 @@ zerocopy_rcv_out: return -ENOPROTOOPT; } - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, len)) + if (copy_to_sockptr(optval, &val, len)) return -EFAULT; return 0; } @@ -4396,7 +4398,8 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, if (level != SOL_TCP) return icsk->icsk_af_ops->getsockopt(sk, level, optname, optval, optlen); - return do_tcp_getsockopt(sk, level, optname, optval, optlen); + return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval), + USER_SOCKPTR(optlen)); } EXPORT_SYMBOL(tcp_getsockopt); @@ -4442,12 +4445,16 @@ static void __tcp_alloc_md5sig_pool(void) * to memory. See smp_rmb() in tcp_get_md5sig_pool() */ smp_wmb(); - tcp_md5sig_pool_populated = true; + /* Paired with READ_ONCE() from tcp_alloc_md5sig_pool() + * and tcp_get_md5sig_pool(). + */ + WRITE_ONCE(tcp_md5sig_pool_populated, true); } bool tcp_alloc_md5sig_pool(void) { - if (unlikely(!tcp_md5sig_pool_populated)) { + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) { mutex_lock(&tcp_md5sig_mutex); if (!tcp_md5sig_pool_populated) { @@ -4458,7 +4465,8 @@ bool tcp_alloc_md5sig_pool(void) mutex_unlock(&tcp_md5sig_mutex); } - return tcp_md5sig_pool_populated; + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + return READ_ONCE(tcp_md5sig_pool_populated); } EXPORT_SYMBOL(tcp_alloc_md5sig_pool); @@ -4474,7 +4482,8 @@ struct tcp_md5sig_pool *tcp_get_md5sig_pool(void) { local_bh_disable(); - if (tcp_md5sig_pool_populated) { + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + if (READ_ONCE(tcp_md5sig_pool_populated)) { /* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */ smp_rmb(); return this_cpu_ptr(&tcp_md5sig_pool); @@ -4745,6 +4754,12 @@ void __init tcp_init(void) SLAB_HWCACHE_ALIGN | SLAB_PANIC | SLAB_ACCOUNT, NULL); + tcp_hashinfo.bind2_bucket_cachep = + kmem_cache_create("tcp_bind2_bucket", + sizeof(struct inet_bind2_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_PANIC | + SLAB_ACCOUNT, + NULL); /* Size and allocate the main established and bind bucket * hash tables. 
@@ -4768,7 +4783,7 @@ void __init tcp_init(void) panic("TCP: failed to alloc ehash_locks"); tcp_hashinfo.bhash = alloc_large_system_hash("TCP bind", - sizeof(struct inet_bind_hashbucket), + 2 * sizeof(struct inet_bind_hashbucket), tcp_hashinfo.ehash_mask + 1, 17, /* one slot per 128 KB of memory */ 0, @@ -4777,11 +4792,15 @@ void __init tcp_init(void) 0, 64 * 1024); tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size; + tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size; for (i = 0; i < tcp_hashinfo.bhash_size; i++) { spin_lock_init(&tcp_hashinfo.bhash[i].lock); INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain); + spin_lock_init(&tcp_hashinfo.bhash2[i].lock); + INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain); } + tcp_hashinfo.pernet = false; cnt = tcp_hashinfo.ehash_mask + 1; sysctl_tcp_max_orphans = cnt / 2; diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c index 75a1c985f49a..01b50fa79189 100644 --- a/net/ipv4/tcp_diag.c +++ b/net/ipv4/tcp_diag.c @@ -181,13 +181,21 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { - inet_diag_dump_icsk(&tcp_hashinfo, skb, cb, r); + struct inet_hashinfo *hinfo; + + hinfo = sock_net(cb->skb->sk)->ipv4.tcp_death_row.hashinfo; + + inet_diag_dump_icsk(hinfo, skb, cb, r); } static int tcp_diag_dump_one(struct netlink_callback *cb, const struct inet_diag_req_v2 *req) { - return inet_diag_dump_one_icsk(&tcp_hashinfo, cb, req); + struct inet_hashinfo *hinfo; + + hinfo = sock_net(cb->skb->sk)->ipv4.tcp_death_row.hashinfo; + + return inet_diag_dump_one_icsk(hinfo, cb, req); } #ifdef CONFIG_INET_DIAG_DESTROY @@ -195,9 +203,13 @@ static int tcp_diag_destroy(struct sk_buff *in_skb, const struct inet_diag_req_v2 *req) { struct net *net = sock_net(in_skb->sk); - struct sock *sk = inet_diag_find_one_icsk(net, &tcp_hashinfo, req); + struct inet_hashinfo *hinfo; + struct sock *sk; int err; + hinfo = net->ipv4.tcp_death_row.hashinfo; + sk = inet_diag_find_one_icsk(net, hinfo, req); + if (IS_ERR(sk)) return PTR_ERR(sk); diff --git a/net/ipv4/tcp_fastopen.c b/net/ipv4/tcp_fastopen.c index 825b216d11f5..45cc7f1ca296 100644 --- a/net/ipv4/tcp_fastopen.c +++ b/net/ipv4/tcp_fastopen.c @@ -272,8 +272,9 @@ static struct sock *tcp_fastopen_create_child(struct sock *sk, * The request socket is not added to the ehash * because it's been added to the accept queue directly. */ + req->timeout = tcp_timeout_init(child); inet_csk_reset_xmit_timer(child, ICSK_TIME_RETRANS, - TCP_TIMEOUT_INIT, TCP_RTO_MAX); + req->timeout, TCP_RTO_MAX); refcount_set(&req->rsk_refcnt, 2); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 5b019ba2b9d2..6376ad915765 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -199,16 +199,18 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, /* This will initiate an outgoing connection. 
*/ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { + struct inet_bind_hashbucket *prev_addr_hashbucket = NULL; struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; + struct inet_timewait_death_row *tcp_death_row; + __be32 daddr, nexthop, prev_sk_rcv_saddr; struct inet_sock *inet = inet_sk(sk); struct tcp_sock *tp = tcp_sk(sk); + struct ip_options_rcu *inet_opt; + struct net *net = sock_net(sk); __be16 orig_sport, orig_dport; - __be32 daddr, nexthop; struct flowi4 *fl4; struct rtable *rt; int err; - struct ip_options_rcu *inet_opt; - struct inet_timewait_death_row *tcp_death_row = sock_net(sk)->ipv4.tcp_death_row; if (addr_len < sizeof(struct sockaddr_in)) return -EINVAL; @@ -234,7 +236,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) if (IS_ERR(rt)) { err = PTR_ERR(rt); if (err == -ENETUNREACH) - IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES); + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); return err; } @@ -246,10 +248,29 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) if (!inet_opt || !inet_opt->opt.srr) daddr = fl4->daddr; - if (!inet->inet_saddr) + tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; + + if (!inet->inet_saddr) { + if (inet_csk(sk)->icsk_bind2_hash) { + prev_addr_hashbucket = inet_bhashfn_portaddr(tcp_death_row->hashinfo, + sk, net, inet->inet_num); + prev_sk_rcv_saddr = sk->sk_rcv_saddr; + } inet->inet_saddr = fl4->saddr; + } + sk_rcv_saddr_set(sk, inet->inet_saddr); + if (prev_addr_hashbucket) { + err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk); + if (err) { + inet->inet_saddr = 0; + sk_rcv_saddr_set(sk, prev_sk_rcv_saddr); + ip_rt_put(rt); + return err; + } + } + if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) { /* Reset inherited state */ tp->rx_opt.ts_recent = 0; @@ -298,8 +319,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) inet->inet_daddr, inet->inet_sport, usin->sin_port)); - tp->tsoffset = secure_tcp_ts_off(sock_net(sk), - inet->inet_saddr, + tp->tsoffset = secure_tcp_ts_off(net, inet->inet_saddr, inet->inet_daddr); } @@ -475,9 +495,9 @@ int tcp_v4_err(struct sk_buff *skb, u32 info) int err; struct net *net = dev_net(skb->dev); - sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr, - th->dest, iph->saddr, ntohs(th->source), - inet_iif(skb), 0); + sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, + iph->daddr, th->dest, iph->saddr, + ntohs(th->source), inet_iif(skb), 0); if (!sk) { __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); return -ENOENT; @@ -740,8 +760,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) * Incoming packet is checked with md5 hash with finding key, * no RST generated if md5 hash doesn't match. 
*/ - sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0, - ip_hdr(skb)->saddr, + sk1 = __inet_lookup_listener(net, net->ipv4.tcp_death_row.hashinfo, + NULL, 0, ip_hdr(skb)->saddr, th->source, ip_hdr(skb)->daddr, ntohs(th->source), dif, sdif); /* don't send rst if it can't find key */ @@ -1709,6 +1729,7 @@ EXPORT_SYMBOL(tcp_v4_do_rcv); int tcp_v4_early_demux(struct sk_buff *skb) { + struct net *net = dev_net(skb->dev); const struct iphdr *iph; const struct tcphdr *th; struct sock *sk; @@ -1725,7 +1746,7 @@ int tcp_v4_early_demux(struct sk_buff *skb) if (th->doff < sizeof(struct tcphdr) / 4) return 0; - sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo, + sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, iph->saddr, th->source, iph->daddr, ntohs(th->dest), skb->skb_iif, inet_sdif(skb)); @@ -1951,7 +1972,8 @@ int tcp_v4_rcv(struct sk_buff *skb) th = (const struct tcphdr *)skb->data; iph = ip_hdr(skb); lookup: - sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source, + sk = __inet_lookup_skb(net->ipv4.tcp_death_row.hashinfo, + skb, __tcp_hdrlen(th), th->source, th->dest, sdif, &refcounted); if (!sk) goto no_tcp_socket; @@ -2133,9 +2155,9 @@ do_time_wait: } switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) { case TCP_TW_SYN: { - struct sock *sk2 = inet_lookup_listener(dev_net(skb->dev), - &tcp_hashinfo, skb, - __tcp_hdrlen(th), + struct sock *sk2 = inet_lookup_listener(net, + net->ipv4.tcp_death_row.hashinfo, + skb, __tcp_hdrlen(th), iph->saddr, th->source, iph->daddr, th->dest, inet_iif(skb), @@ -2285,15 +2307,16 @@ static bool seq_sk_match(struct seq_file *seq, const struct sock *sk) */ static void *listening_get_first(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; st->offset = 0; - for (; st->bucket <= tcp_hashinfo.lhash2_mask; st->bucket++) { + for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) { struct inet_listen_hashbucket *ilb2; struct hlist_nulls_node *node; struct sock *sk; - ilb2 = &tcp_hashinfo.lhash2[st->bucket]; + ilb2 = &hinfo->lhash2[st->bucket]; if (hlist_nulls_empty(&ilb2->nulls_head)) continue; @@ -2318,6 +2341,7 @@ static void *listening_get_next(struct seq_file *seq, void *cur) struct tcp_iter_state *st = seq->private; struct inet_listen_hashbucket *ilb2; struct hlist_nulls_node *node; + struct inet_hashinfo *hinfo; struct sock *sk = cur; ++st->num; @@ -2329,7 +2353,8 @@ static void *listening_get_next(struct seq_file *seq, void *cur) return sk; } - ilb2 = &tcp_hashinfo.lhash2[st->bucket]; + hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + ilb2 = &hinfo->lhash2[st->bucket]; spin_unlock(&ilb2->lock); ++st->bucket; return listening_get_first(seq); @@ -2351,9 +2376,10 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos) return rc; } -static inline bool empty_bucket(const struct tcp_iter_state *st) +static inline bool empty_bucket(struct inet_hashinfo *hinfo, + const struct tcp_iter_state *st) { - return hlist_nulls_empty(&tcp_hashinfo.ehash[st->bucket].chain); + return hlist_nulls_empty(&hinfo->ehash[st->bucket].chain); } /* @@ -2362,20 +2388,21 @@ static inline bool empty_bucket(const struct tcp_iter_state *st) */ static void *established_get_first(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; st->offset = 0; - for (; st->bucket <= tcp_hashinfo.ehash_mask; ++st->bucket) { + for (; 
st->bucket <= hinfo->ehash_mask; ++st->bucket) { struct sock *sk; struct hlist_nulls_node *node; - spinlock_t *lock = inet_ehash_lockp(&tcp_hashinfo, st->bucket); + spinlock_t *lock = inet_ehash_lockp(hinfo, st->bucket); /* Lockless fast path for the common case of empty buckets */ - if (empty_bucket(st)) + if (empty_bucket(hinfo, st)) continue; spin_lock_bh(lock); - sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) { + sk_nulls_for_each(sk, node, &hinfo->ehash[st->bucket].chain) { if (seq_sk_match(seq, sk)) return sk; } @@ -2387,9 +2414,10 @@ static void *established_get_first(struct seq_file *seq) static void *established_get_next(struct seq_file *seq, void *cur) { - struct sock *sk = cur; - struct hlist_nulls_node *node; + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; + struct hlist_nulls_node *node; + struct sock *sk = cur; ++st->num; ++st->offset; @@ -2401,7 +2429,7 @@ static void *established_get_next(struct seq_file *seq, void *cur) return sk; } - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); ++st->bucket; return established_get_first(seq); } @@ -2439,6 +2467,7 @@ static void *tcp_get_idx(struct seq_file *seq, loff_t pos) static void *tcp_seek_last_pos(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; int bucket = st->bucket; int offset = st->offset; @@ -2447,7 +2476,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq) switch (st->state) { case TCP_SEQ_STATE_LISTENING: - if (st->bucket > tcp_hashinfo.lhash2_mask) + if (st->bucket > hinfo->lhash2_mask) break; st->state = TCP_SEQ_STATE_LISTENING; rc = listening_get_first(seq); @@ -2459,7 +2488,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq) st->state = TCP_SEQ_STATE_ESTABLISHED; fallthrough; case TCP_SEQ_STATE_ESTABLISHED: - if (st->bucket > tcp_hashinfo.ehash_mask) + if (st->bucket > hinfo->ehash_mask) break; rc = established_get_first(seq); while (offset-- && rc && bucket == st->bucket) @@ -2527,16 +2556,17 @@ EXPORT_SYMBOL(tcp_seq_next); void tcp_seq_stop(struct seq_file *seq, void *v) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; switch (st->state) { case TCP_SEQ_STATE_LISTENING: if (v != SEQ_START_TOKEN) - spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); + spin_unlock(&hinfo->lhash2[st->bucket].lock); break; case TCP_SEQ_STATE_ESTABLISHED: if (v) - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); break; } } @@ -2731,6 +2761,7 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, struct sock *start_sk) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; struct hlist_nulls_node *node; @@ -2750,7 +2781,7 @@ static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, expected++; } } - spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); + spin_unlock(&hinfo->lhash2[st->bucket].lock); return expected; } @@ -2758,6 +2789,7 @@ static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, struct sock *start_sk) { + struct 
inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; struct hlist_nulls_node *node; @@ -2777,13 +2809,14 @@ static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, expected++; } } - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); return expected; } static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; unsigned int expected; @@ -2799,7 +2832,7 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) st->offset = 0; st->bucket++; if (st->state == TCP_SEQ_STATE_LISTENING && - st->bucket > tcp_hashinfo.lhash2_mask) { + st->bucket > hinfo->lhash2_mask) { st->state = TCP_SEQ_STATE_ESTABLISHED; st->bucket = 0; } @@ -3064,7 +3097,7 @@ struct proto tcp_prot = { .slab_flags = SLAB_TYPESAFE_BY_RCU, .twsk_prot = &tcp_timewait_sock_ops, .rsk_prot = &tcp_request_sock_ops, - .h.hashinfo = &tcp_hashinfo, + .h.hashinfo = NULL, .no_autobind = true, .diag_destroy = tcp_abort, }; @@ -3072,19 +3105,43 @@ EXPORT_SYMBOL(tcp_prot); static void __net_exit tcp_sk_exit(struct net *net) { - struct inet_timewait_death_row *tcp_death_row = net->ipv4.tcp_death_row; - if (net->ipv4.tcp_congestion_control) bpf_module_put(net->ipv4.tcp_congestion_control, net->ipv4.tcp_congestion_control->owner); - if (refcount_dec_and_test(&tcp_death_row->tw_refcount)) - kfree(tcp_death_row); } -static int __net_init tcp_sk_init(struct net *net) +static void __net_init tcp_set_hashinfo(struct net *net) { - int cnt; + struct inet_hashinfo *hinfo; + unsigned int ehash_entries; + struct net *old_net; + + if (net_eq(net, &init_net)) + goto fallback; + + old_net = current->nsproxy->net_ns; + ehash_entries = READ_ONCE(old_net->ipv4.sysctl_tcp_child_ehash_entries); + if (!ehash_entries) + goto fallback; + + ehash_entries = roundup_pow_of_two(ehash_entries); + hinfo = inet_pernet_hashinfo_alloc(&tcp_hashinfo, ehash_entries); + if (!hinfo) { + pr_warn("Failed to allocate TCP ehash (entries: %u) " + "for a netns, fallback to the global one\n", + ehash_entries); +fallback: + hinfo = &tcp_hashinfo; + ehash_entries = tcp_hashinfo.ehash_mask + 1; + } + net->ipv4.tcp_death_row.hashinfo = hinfo; + net->ipv4.tcp_death_row.sysctl_max_tw_buckets = ehash_entries / 2; + net->ipv4.sysctl_max_syn_backlog = max(128U, ehash_entries / 128); +} + +static int __net_init tcp_sk_init(struct net *net) +{ net->ipv4.sysctl_tcp_ecn = 2; net->ipv4.sysctl_tcp_ecn_fallback = 1; @@ -3110,15 +3167,9 @@ static int __net_init tcp_sk_init(struct net *net) net->ipv4.sysctl_tcp_tw_reuse = 2; net->ipv4.sysctl_tcp_no_ssthresh_metrics_save = 1; - net->ipv4.tcp_death_row = kzalloc(sizeof(struct inet_timewait_death_row), GFP_KERNEL); - if (!net->ipv4.tcp_death_row) - return -ENOMEM; - refcount_set(&net->ipv4.tcp_death_row->tw_refcount, 1); - cnt = tcp_hashinfo.ehash_mask + 1; - net->ipv4.tcp_death_row->sysctl_max_tw_buckets = cnt / 2; - net->ipv4.tcp_death_row->hashinfo = &tcp_hashinfo; + refcount_set(&net->ipv4.tcp_death_row.tw_refcount, 1); + tcp_set_hashinfo(net); - net->ipv4.sysctl_max_syn_backlog = max(128, cnt / 128); net->ipv4.sysctl_tcp_sack = 1; net->ipv4.sysctl_tcp_window_scaling = 1; net->ipv4.sysctl_tcp_timestamps = 1; @@ -3180,10 +3231,13 @@ static void __net_exit tcp_sk_exit_batch(struct list_head 
*net_exit_list) { struct net *net; - inet_twsk_purge(&tcp_hashinfo, AF_INET); + tcp_twsk_purge(net_exit_list, AF_INET); - list_for_each_entry(net, net_exit_list, exit_list) + list_for_each_entry(net, net_exit_list, exit_list) { + inet_pernet_hashinfo_free(net->ipv4.tcp_death_row.hashinfo); + WARN_ON_ONCE(!refcount_dec_and_test(&net->ipv4.tcp_death_row.tw_refcount)); tcp_fastopen_ctx_destroy(net); + } } static struct pernet_operations __net_initdata tcp_sk_ops = { diff --git a/net/ipv4/tcp_metrics.c b/net/ipv4/tcp_metrics.c index d58e672be31c..82f4575f9cd9 100644 --- a/net/ipv4/tcp_metrics.c +++ b/net/ipv4/tcp_metrics.c @@ -969,6 +969,7 @@ static struct genl_family tcp_metrics_nl_family __ro_after_init = { .module = THIS_MODULE, .small_ops = tcp_metrics_nl_ops, .n_small_ops = ARRAY_SIZE(tcp_metrics_nl_ops), + .resv_start_op = TCP_METRICS_CMD_DEL + 1, }; static unsigned int tcpmhash_entries; diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index cb95d88497ae..79f30f026d89 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -247,10 +247,10 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) { const struct inet_connection_sock *icsk = inet_csk(sk); const struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); struct inet_timewait_sock *tw; - struct inet_timewait_death_row *tcp_death_row = sock_net(sk)->ipv4.tcp_death_row; - tw = inet_twsk_alloc(sk, tcp_death_row, state); + tw = inet_twsk_alloc(sk, &net->ipv4.tcp_death_row, state); if (tw) { struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); @@ -319,14 +319,14 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) /* Linkage updates. * Note that access to tw after this point is illegal. */ - inet_twsk_hashdance(tw, sk, &tcp_hashinfo); + inet_twsk_hashdance(tw, sk, net->ipv4.tcp_death_row.hashinfo); local_bh_enable(); } else { /* Sorry, if we're out of memory, just CLOSE this * socket up. We've got bigger problems than * non-graceful socket closings. */ - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPTIMEWAITOVERFLOW); + NET_INC_STATS(net, LINUX_MIB_TCPTIMEWAITOVERFLOW); } tcp_update_metrics(sk); @@ -347,6 +347,26 @@ void tcp_twsk_destructor(struct sock *sk) } EXPORT_SYMBOL_GPL(tcp_twsk_destructor); +void tcp_twsk_purge(struct list_head *net_exit_list, int family) +{ + bool purged_once = false; + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) { + /* The last refcount is decremented in tcp_sk_exit_batch() */ + if (refcount_read(&net->ipv4.tcp_death_row.tw_refcount) == 1) + continue; + + if (net->ipv4.tcp_death_row.hashinfo->pernet) { + inet_twsk_purge(net->ipv4.tcp_death_row.hashinfo, family); + } else if (!purged_once) { + inet_twsk_purge(&tcp_hashinfo, family); + purged_once = true; + } + } +} +EXPORT_SYMBOL_GPL(tcp_twsk_purge); + /* Warning : This function is called without sk_listener being locked. * Be sure to read socket fields once, as their value could change under us. 
*/ @@ -541,6 +561,7 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, newtp->fastopen_req = NULL; RCU_INIT_POINTER(newtp->fastopen_rsk, NULL); + newtp->bpf_chg_cc_inprogress = 0; tcp_bpf_clone(sk, newsk); __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS); diff --git a/net/ipv4/tcp_offload.c b/net/ipv4/tcp_offload.c index 30abde86db45..45dda7889387 100644 --- a/net/ipv4/tcp_offload.c +++ b/net/ipv4/tcp_offload.c @@ -195,12 +195,9 @@ struct sk_buff *tcp_gro_receive(struct list_head *head, struct sk_buff *skb) off = skb_gro_offset(skb); hlen = off + sizeof(*th); - th = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - th = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!th)) - goto out; - } + th = skb_gro_header(skb, hlen, off); + if (unlikely(!th)) + goto out; thlen = th->doff * 4; if (thlen < sizeof(*th)) @@ -258,7 +255,15 @@ found: mss = skb_shinfo(p)->gso_size; - flush |= (len - 1) >= mss; + /* If skb is a GRO packet, make sure its gso_size matches prior packet mss. + * If it is a single frame, do not aggregate it if its length + * is bigger than our mss. + */ + if (unlikely(skb_is_gso(skb))) + flush |= (mss != skb_shinfo(skb)->gso_size); + else + flush |= (len - 1) >= mss; + flush |= (ntohl(th2->seq) + skb_gro_len(p)) ^ ntohl(th->seq); #ifdef CONFIG_TLS_DEVICE flush |= p->decrypted ^ skb->decrypted; @@ -272,7 +277,12 @@ found: tcp_flag_word(th2) |= flags & (TCP_FLAG_FIN | TCP_FLAG_PSH); out_check_final: - flush = len < mss; + /* Force a flush if last segment is smaller than mss. */ + if (unlikely(skb_is_gso(skb))) + flush = len != NAPI_GRO_CB(skb)->count * skb_shinfo(skb)->gso_size; + else + flush = len < mss; + flush |= (__force int)(flags & (TCP_FLAG_URG | TCP_FLAG_PSH | TCP_FLAG_RST | TCP_FLAG_SYN | TCP_FLAG_FIN)); diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 290019de766d..c69f4d966024 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -1875,15 +1875,20 @@ static void tcp_cwnd_validate(struct sock *sk, bool is_cwnd_limited) const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; struct tcp_sock *tp = tcp_sk(sk); - /* Track the maximum number of outstanding packets in each - * window, and remember whether we were cwnd-limited then. + /* Track the strongest available signal of the degree to which the cwnd + * is fully utilized. If cwnd-limited then remember that fact for the + * current window. If not cwnd-limited then track the maximum number of + * outstanding packets in the current window. (If cwnd-limited then we + * chose to not update tp->max_packets_out to avoid an extra else + * clause with no functional impact.) 
*/ - if (!before(tp->snd_una, tp->max_packets_seq) || - tp->packets_out > tp->max_packets_out || - is_cwnd_limited) { - tp->max_packets_out = tp->packets_out; - tp->max_packets_seq = tp->snd_nxt; + if (!before(tp->snd_una, tp->cwnd_usage_seq) || + is_cwnd_limited || + (!tp->is_cwnd_limited && + tp->packets_out > tp->max_packets_out)) { tp->is_cwnd_limited = is_cwnd_limited; + tp->max_packets_out = tp->packets_out; + tp->cwnd_usage_seq = tp->snd_nxt; } if (tcp_is_cwnd_limited(sk)) { diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index b4dfb82d6ecb..cb79127f45c3 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -428,7 +428,7 @@ static void tcp_fastopen_synack_timer(struct sock *sk, struct request_sock *req) if (!tp->retrans_stamp) tp->retrans_stamp = tcp_time_stamp(tp); inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - TCP_TIMEOUT_INIT << req->num_timeout, TCP_RTO_MAX); + req->timeout << req->num_timeout, TCP_RTO_MAX); } diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 560d9eadeaa5..d63118ce5900 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1801,41 +1801,29 @@ EXPORT_SYMBOL(__skb_recv_udp); int udp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) { - int copied = 0; - - while (1) { - struct sk_buff *skb; - int err, used; - - skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); - if (!skb) - return err; + struct sk_buff *skb; + int err, copied; - if (udp_lib_checksum_complete(skb)) { - __UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, - IS_UDPLITE(sk)); - __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, - IS_UDPLITE(sk)); - atomic_inc(&sk->sk_drops); - kfree_skb(skb); - continue; - } +try_again: + skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); + if (!skb) + return err; - WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); - used = recv_actor(sk, skb); - if (used <= 0) { - if (!copied) - copied = used; - kfree_skb(skb); - break; - } else if (used <= skb->len) { - copied += used; - } + if (udp_lib_checksum_complete(skb)) { + int is_udplite = IS_UDPLITE(sk); + struct net *net = sock_net(sk); + __UDP_INC_STATS(net, UDP_MIB_CSUMERRORS, is_udplite); + __UDP_INC_STATS(net, UDP_MIB_INERRORS, is_udplite); + atomic_inc(&sk->sk_drops); kfree_skb(skb); - break; + goto try_again; } + WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); + copied = recv_actor(sk, skb); + kfree_skb(skb); + return copied; } EXPORT_SYMBOL(udp_read_skb); diff --git a/net/ipv4/xfrm4_tunnel.c b/net/ipv4/xfrm4_tunnel.c index 9d4f418f1bf8..8489fa106583 100644 --- a/net/ipv4/xfrm4_tunnel.c +++ b/net/ipv4/xfrm4_tunnel.c @@ -22,13 +22,17 @@ static int ipip_xfrm_rcv(struct xfrm_state *x, struct sk_buff *skb) return ip_hdr(skb)->protocol; } -static int ipip_init_state(struct xfrm_state *x) +static int ipip_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { - if (x->props.mode != XFRM_MODE_TUNNEL) + if (x->props.mode != XFRM_MODE_TUNNEL) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel can only be used with tunnel mode"); return -EINVAL; + } - if (x->encap) + if (x->encap) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel is not compatible with encapsulation"); return -EINVAL; + } x->props.header_len = sizeof(struct iphdr); |