aboutsummaryrefslogtreecommitdiff
path: root/net/ipv4/raw_diag.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/raw_diag.c')
-rw-r--r--net/ipv4/raw_diag.c22
1 files changed, 13 insertions, 9 deletions
diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c
index b6d92dc7b051..5f208e840d85 100644
--- a/net/ipv4/raw_diag.c
+++ b/net/ipv4/raw_diag.c
@@ -57,31 +57,32 @@ static bool raw_lookup(struct net *net, struct sock *sk,
static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r)
{
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
+ struct hlist_nulls_head *hlist;
+ struct hlist_nulls_node *hnode;
struct sock *sk;
int slot;
if (IS_ERR(hashinfo))
return ERR_CAST(hashinfo);
- read_lock(&hashinfo->lock);
+ rcu_read_lock();
for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) {
- sk_for_each(sk, &hashinfo->ht[slot]) {
+ hlist = &hashinfo->ht[slot];
+ hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (raw_lookup(net, sk, r)) {
/*
* Grab it and keep until we fill
- * diag meaage to be reported, so
+ * diag message to be reported, so
* caller should call sock_put then.
- * We can do that because we're keeping
- * hashinfo->lock here.
*/
- sock_hold(sk);
- goto out_unlock;
+ if (refcount_inc_not_zero(&sk->sk_refcnt))
+ goto out_unlock;
}
}
}
sk = ERR_PTR(-ENOENT);
out_unlock:
- read_unlock(&hashinfo->lock);
+ rcu_read_unlock();
return sk;
}
@@ -141,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct net *net = sock_net(skb->sk);
struct inet_diag_dump_data *cb_data;
+ struct hlist_nulls_head *hlist;
+ struct hlist_nulls_node *hnode;
int num, s_num, slot, s_slot;
struct sock *sk = NULL;
struct nlattr *bc;
@@ -157,7 +160,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) {
num = 0;
- sk_for_each(sk, &hashinfo->ht[slot]) {
+ hlist = &hashinfo->ht[slot];
+ hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
struct inet_sock *inet = inet_sk(sk);
if (!net_eq(sock_net(sk), net))