diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/ipv4/tcp_ipv4.c | 237 |
1 files changed, 231 insertions, 6 deletions
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index d38b4379dca4..84ac0135d389 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -2687,6 +2687,15 @@ out: } #ifdef CONFIG_BPF_SYSCALL +struct bpf_tcp_iter_state { + struct tcp_iter_state state; + unsigned int cur_sk; + unsigned int end_sk; + unsigned int max_sk; + struct sock **batch; + bool st_bucket_done; +}; + struct bpf_iter__tcp { __bpf_md_ptr(struct bpf_iter_meta *, meta); __bpf_md_ptr(struct sock_common *, sk_common); @@ -2705,16 +2714,204 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, return bpf_iter_run_prog(prog, &ctx); } +static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter) +{ + while (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++]); +} + +static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, + unsigned int new_batch_sz) +{ + struct sock **new_batch; + + new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, + GFP_USER | __GFP_NOWARN); + if (!new_batch) + return -ENOMEM; + + bpf_iter_tcp_put_batch(iter); + kvfree(iter->batch); + iter->batch = new_batch; + iter->max_sk = new_batch_sz; + + return 0; +} + +static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct inet_connection_sock *icsk; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + icsk = inet_csk(start_sk); + inet_lhash2_for_each_icsk_continue(icsk) { + sk = (struct sock *)icsk; + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); + + return expected; +} + +static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct hlist_nulls_node *node; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + sk = sk_nulls_next(start_sk); + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + + return expected; +} + +static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + unsigned int expected; + bool resized = false; + struct sock *sk; + + /* The st->bucket is done. Directly advance to the next + * bucket instead of having the tcp_seek_last_pos() to skip + * one by one in the current bucket and eventually find out + * it has to advance to the next bucket. + */ + if (iter->st_bucket_done) { + st->offset = 0; + st->bucket++; + if (st->state == TCP_SEQ_STATE_LISTENING && + st->bucket > tcp_hashinfo.lhash2_mask) { + st->state = TCP_SEQ_STATE_ESTABLISHED; + st->bucket = 0; + } + } + +again: + /* Get a new batch */ + iter->cur_sk = 0; + iter->end_sk = 0; + iter->st_bucket_done = false; + + sk = tcp_seek_last_pos(seq); + if (!sk) + return NULL; /* Done */ + + if (st->state == TCP_SEQ_STATE_LISTENING) + expected = bpf_iter_tcp_listening_batch(seq, sk); + else + expected = bpf_iter_tcp_established_batch(seq, sk); + + if (iter->end_sk == expected) { + iter->st_bucket_done = true; + return sk; + } + + if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) { + resized = true; + goto again; + } + + return sk; +} + +static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) +{ + /* bpf iter does not support lseek, so it always + * continue from where it was stop()-ped. + */ + if (*pos) + return bpf_iter_tcp_batch(seq); + + return SEQ_START_TOKEN; +} + +static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct sock *sk; + + /* Whenever seq_next() is called, the iter->cur_sk is + * done with seq_show(), so advance to the next sk in + * the batch. + */ + if (iter->cur_sk < iter->end_sk) { + /* Keeping st->num consistent in tcp_iter_state. + * bpf_iter_tcp does not use st->num. + * meta.seq_num is used instead. + */ + st->num++; + /* Move st->offset to the next sk in the bucket such that + * the future start() will resume at st->offset in + * st->bucket. See tcp_seek_last_pos(). + */ + st->offset++; + sock_put(iter->batch[iter->cur_sk++]); + } + + if (iter->cur_sk < iter->end_sk) + sk = iter->batch[iter->cur_sk]; + else + sk = bpf_iter_tcp_batch(seq); + + ++*pos; + /* Keeping st->last_pos consistent in tcp_iter_state. + * bpf iter does not do lseek, so st->last_pos always equals to *pos. + */ + st->last_pos = *pos; + return sk; +} + static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) { struct bpf_iter_meta meta; struct bpf_prog *prog; struct sock *sk = v; + bool slow; uid_t uid; + int ret; if (v == SEQ_START_TOKEN) return 0; + if (sk_fullsock(sk)) + slow = lock_sock_fast(sk); + + if (unlikely(sk_unhashed(sk))) { + ret = SEQ_SKIP; + goto unlock; + } + if (sk->sk_state == TCP_TIME_WAIT) { uid = 0; } else if (sk->sk_state == TCP_NEW_SYN_RECV) { @@ -2728,11 +2925,18 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) meta.seq = seq; prog = bpf_iter_get_info(&meta, false); - return tcp_prog_seq_show(prog, &meta, v, uid); + ret = tcp_prog_seq_show(prog, &meta, v, uid); + +unlock: + if (sk_fullsock(sk)) + unlock_sock_fast(sk, slow); + return ret; + } static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) { + struct bpf_tcp_iter_state *iter = seq->private; struct bpf_iter_meta meta; struct bpf_prog *prog; @@ -2743,13 +2947,16 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) (void)tcp_prog_seq_show(prog, &meta, v, 0); } - tcp_seq_stop(seq, v); + if (iter->cur_sk < iter->end_sk) { + bpf_iter_tcp_put_batch(iter); + iter->st_bucket_done = false; + } } static const struct seq_operations bpf_iter_tcp_seq_ops = { .show = bpf_iter_tcp_seq_show, - .start = tcp_seq_start, - .next = tcp_seq_next, + .start = bpf_iter_tcp_seq_start, + .next = bpf_iter_tcp_seq_next, .stop = bpf_iter_tcp_seq_stop, }; #endif @@ -3017,21 +3224,39 @@ static struct pernet_operations __net_initdata tcp_sk_ops = { DEFINE_BPF_ITER_FUNC(tcp, struct bpf_iter_meta *meta, struct sock_common *sk_common, uid_t uid) +#define INIT_BATCH_SZ 16 + static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux) { - return bpf_iter_init_seq_net(priv_data, aux); + struct bpf_tcp_iter_state *iter = priv_data; + int err; + + err = bpf_iter_init_seq_net(priv_data, aux); + if (err) + return err; + + err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ); + if (err) { + bpf_iter_fini_seq_net(priv_data); + return err; + } + + return 0; } static void bpf_iter_fini_tcp(void *priv_data) { + struct bpf_tcp_iter_state *iter = priv_data; + bpf_iter_fini_seq_net(priv_data); + kvfree(iter->batch); } static const struct bpf_iter_seq_info tcp_seq_info = { .seq_ops = &bpf_iter_tcp_seq_ops, .init_seq_private = bpf_iter_init_tcp, .fini_seq_private = bpf_iter_fini_tcp, - .seq_priv_size = sizeof(struct tcp_iter_state), + .seq_priv_size = sizeof(struct bpf_tcp_iter_state), }; static struct bpf_iter_reg tcp_reg_info = { |