aboutsummaryrefslogtreecommitdiff
path: root/drivers/net/wireguard/receive.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/receive.c')
-rw-r--r--drivers/net/wireguard/receive.c65
1 files changed, 32 insertions, 33 deletions
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index da3b782ab7d3..91438144e4f7 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -226,40 +226,39 @@ void wg_packet_handshake_receive_worker(struct work_struct *work)
static void keep_key_fresh(struct wg_peer *peer)
{
struct noise_keypair *keypair;
- bool send = false;
+ bool send;
if (peer->sent_lastminute_handshake)
return;
rcu_read_lock_bh();
keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
- if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) &&
- keypair->i_am_the_initiator &&
- unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
- REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
- send = true;
+ send = keypair && READ_ONCE(keypair->sending.is_valid) &&
+ keypair->i_am_the_initiator &&
+ wg_birthdate_has_expired(keypair->sending.birthdate,
+ REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT);
rcu_read_unlock_bh();
- if (send) {
+ if (unlikely(send)) {
peer->sent_lastminute_handshake = true;
wg_packet_send_queued_handshake_initiation(peer, false);
}
}
-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
+static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
struct scatterlist sg[MAX_SKB_FRAGS + 8];
struct sk_buff *trailer;
unsigned int offset;
int num_frags;
- if (unlikely(!key))
+ if (unlikely(!keypair))
return false;
- if (unlikely(!READ_ONCE(key->is_valid) ||
- wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
- key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
- WRITE_ONCE(key->is_valid, false);
+ if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
+ wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
+ keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+ WRITE_ONCE(keypair->receiving.is_valid, false);
return false;
}
@@ -284,7 +283,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
PACKET_CB(skb)->nonce,
- key->key))
+ keypair->receiving.key))
return false;
/* Another ugly situation of pushing and pulling the header so as to
@@ -299,41 +298,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
}
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static bool counter_validate(union noise_counter *counter, u64 their_counter)
+static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
{
unsigned long index, index_current, top, i;
bool ret = false;
- spin_lock_bh(&counter->receive.lock);
+ spin_lock_bh(&counter->lock);
- if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+ if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
their_counter >= REJECT_AFTER_MESSAGES))
goto out;
++their_counter;
if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
- counter->receive.counter))
+ counter->counter))
goto out;
index = their_counter >> ilog2(BITS_PER_LONG);
- if (likely(their_counter > counter->receive.counter)) {
- index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+ if (likely(their_counter > counter->counter)) {
+ index_current = counter->counter >> ilog2(BITS_PER_LONG);
top = min_t(unsigned long, index - index_current,
COUNTER_BITS_TOTAL / BITS_PER_LONG);
for (i = 1; i <= top; ++i)
- counter->receive.backtrack[(i + index_current) &
+ counter->backtrack[(i + index_current) &
((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
- counter->receive.counter = their_counter;
+ counter->counter = their_counter;
}
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
- &counter->receive.backtrack[index]);
+ &counter->backtrack[index]);
out:
- spin_unlock_bh(&counter->receive.lock);
+ spin_unlock_bh(&counter->lock);
return ret;
}
@@ -393,13 +392,11 @@ static void wg_packet_consume_data_done(struct wg_peer *peer,
len = ntohs(ip_hdr(skb)->tot_len);
if (unlikely(len < sizeof(struct iphdr)))
goto dishonest_packet_size;
- if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
- IP_ECN_set_ce(ip_hdr(skb));
+ INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
} else if (skb->protocol == htons(ETH_P_IPV6)) {
len = ntohs(ipv6_hdr(skb)->payload_len) +
sizeof(struct ipv6hdr);
- if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
- IP6_ECN_set_ce(skb, ipv6_hdr(skb));
+ INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
} else {
goto dishonest_packet_type;
}
@@ -475,19 +472,19 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
if (unlikely(state != PACKET_STATE_CRYPTED))
goto next;
- if (unlikely(!counter_validate(&keypair->receiving.counter,
+ if (unlikely(!counter_validate(&keypair->receiving_counter,
PACKET_CB(skb)->nonce))) {
net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
peer->device->dev->name,
PACKET_CB(skb)->nonce,
- keypair->receiving.counter.receive.counter);
+ keypair->receiving_counter.counter);
goto next;
}
if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
goto next;
- wg_reset_packet(skb);
+ wg_reset_packet(skb, false);
wg_packet_consume_data_done(peer, skb, &endpoint);
free = false;
@@ -514,10 +511,12 @@ void wg_packet_decrypt_worker(struct work_struct *work)
struct sk_buff *skb;
while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
- enum packet_state state = likely(decrypt_packet(skb,
- &PACKET_CB(skb)->keypair->receiving)) ?
+ enum packet_state state =
+ likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
wg_queue_enqueue_per_peer_napi(skb, state);
+ if (need_resched())
+ cond_resched();
}
}