aboutsummaryrefslogtreecommitdiff
path: root/drivers/net/wireguard/selftest/allowedips.c
diff options
context:
space:
mode:
authorJason A. Donenfeld2019-12-09 00:27:34 +0100
committerDavid S. Miller2019-12-08 17:48:42 -0800
commite7096c131e5161fa3b8e52a650d7719d2857adfd (patch)
tree269266506f365dd23e8ccf9a16dcfc2d8af1b0c5 /drivers/net/wireguard/selftest/allowedips.c
parente42617b825f8073569da76dc4510bfa019b1c35a (diff)
net: WireGuard secure network tunnel
WireGuard is a layer 3 secure networking tunnel made specifically for the kernel, that aims to be much simpler and easier to audit than IPsec. Extensive documentation and description of the protocol and considerations, along with formal proofs of the cryptography, are available at: * https://www.wireguard.com/ * https://www.wireguard.com/papers/wireguard.pdf This commit implements WireGuard as a simple network device driver, accessible in the usual RTNL way used by virtual network drivers. It makes use of the udp_tunnel APIs, GRO, GSO, NAPI, and the usual set of networking subsystem APIs. It has a somewhat novel multicore queueing system designed for maximum throughput and minimal latency of encryption operations, but it is implemented modestly using workqueues and NAPI. Configuration is done via generic Netlink, and following a review from the Netlink maintainer a year ago, several high profile userspace tools have already implemented the API. This commit also comes with several different tests, both in-kernel tests and out-of-kernel tests based on network namespaces, taking profit of the fact that sockets used by WireGuard intentionally stay in the namespace the WireGuard interface was originally created, exactly like the semantics of userspace tun devices. See wireguard.com/netns/ for pictures and examples. The source code is fairly short, but rather than combining everything into a single file, WireGuard is developed as cleanly separable files, making auditing and comprehension easier. Things are laid out as follows: * noise.[ch], cookie.[ch], messages.h: These implement the bulk of the cryptographic aspects of the protocol, and are mostly data-only in nature, taking in buffers of bytes and spitting out buffers of bytes. They also handle reference counting for their various shared pieces of data, like keys and key lists. * ratelimiter.[ch]: Used as an integral part of cookie.[ch] for ratelimiting certain types of cryptographic operations in accordance with particular WireGuard semantics. * allowedips.[ch], peerlookup.[ch]: The main lookup structures of WireGuard, the former being trie-like with particular semantics, an integral part of the design of the protocol, and the latter just being nice helper functions around the various hashtables we use. * device.[ch]: Implementation of functions for the netdevice and for rtnl, responsible for maintaining the life of a given interface and wiring it up to the rest of WireGuard. * peer.[ch]: Each interface has a list of peers, with helper functions available here for creation, destruction, and reference counting. * socket.[ch]: Implementation of functions related to udp_socket and the general set of kernel socket APIs, for sending and receiving ciphertext UDP packets, and taking care of WireGuard-specific sticky socket routing semantics for the automatic roaming. * netlink.[ch]: Userspace API entry point for configuring WireGuard peers and devices. The API has been implemented by several userspace tools and network management utility, and the WireGuard project distributes the basic wg(8) tool. * queueing.[ch]: Shared function on the rx and tx path for handling the various queues used in the multicore algorithms. * send.c: Handles encrypting outgoing packets in parallel on multiple cores, before sending them in order on a single core, via workqueues and ring buffers. Also handles sending handshake and cookie messages as part of the protocol, in parallel. * receive.c: Handles decrypting incoming packets in parallel on multiple cores, before passing them off in order to be ingested via the rest of the networking subsystem with GRO via the typical NAPI poll function. Also handles receiving handshake and cookie messages as part of the protocol, in parallel. * timers.[ch]: Uses the timer wheel to implement protocol particular event timeouts, and gives a set of very simple event-driven entry point functions for callers. * main.c, version.h: Initialization and deinitialization of the module. * selftest/*.h: Runtime unit tests for some of the most security sensitive functions. * tools/testing/selftests/wireguard/netns.sh: Aforementioned testing script using network namespaces. This commit aims to be as self-contained as possible, implementing WireGuard as a standalone module not needing much special handling or coordination from the network subsystem. I expect for future optimizations to the network stack to positively improve WireGuard, and vice-versa, but for the time being, this exists as intentionally standalone. We introduce a menu option for CONFIG_WIREGUARD, as well as providing a verbose debug log and self-tests via CONFIG_WIREGUARD_DEBUG. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> Cc: David Miller <davem@davemloft.net> Cc: Greg KH <gregkh@linuxfoundation.org> Cc: Linus Torvalds <torvalds@linux-foundation.org> Cc: Herbert Xu <herbert@gondor.apana.org.au> Cc: linux-crypto@vger.kernel.org Cc: linux-kernel@vger.kernel.org Cc: netdev@vger.kernel.org Signed-off-by: David S. Miller <davem@davemloft.net>
Diffstat (limited to 'drivers/net/wireguard/selftest/allowedips.c')
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c683
1 files changed, 683 insertions, 0 deletions
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
new file mode 100644
index 000000000000..846db14cb046
--- /dev/null
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -0,0 +1,683 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ *
+ * This contains some basic static unit tests for the allowedips data structure.
+ * It also has two additional modes that are disabled and meant to be used by
+ * folks directly playing with this file. If you define the macro
+ * DEBUG_PRINT_TRIE_GRAPHVIZ to be 1, then every time there's a full tree in
+ * memory, it will be printed out as KERN_DEBUG in a format that can be passed
+ * to graphviz (the dot command) to visualize it. If you define the macro
+ * DEBUG_RANDOM_TRIE to be 1, then there will be an extremely costly set of
+ * randomized tests done against a trivial implementation, which may take
+ * upwards of a half-hour to complete. There's no set of users who should be
+ * enabling these, and the only developers that should go anywhere near these
+ * nobs are the ones who are reading this comment.
+ */
+
+#ifdef DEBUG
+
+#include <linux/siphash.h>
+
+static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
+ u8 cidr)
+{
+ swap_endian(dst, src, bits);
+ memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
+ if (cidr)
+ dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
+}
+
+static __init void print_node(struct allowedips_node *node, u8 bits)
+{
+ char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
+ char *fmt_declaration = KERN_DEBUG
+ "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ char *style = "dotted";
+ u8 ip1[16], ip2[16];
+ u32 color = 0;
+
+ if (bits == 32) {
+ fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
+ fmt_declaration = KERN_DEBUG
+ "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
+ } else if (bits == 128) {
+ fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
+ fmt_declaration = KERN_DEBUG
+ "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
+ }
+ if (node->peer) {
+ hsiphash_key_t key = { { 0 } };
+
+ memcpy(&key, &node->peer, sizeof(node->peer));
+ color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 |
+ hsiphash_1u32(0xbabecafe, &key) % 200 << 8 |
+ hsiphash_1u32(0xabad1dea, &key) % 200;
+ style = "bold";
+ }
+ swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
+ printk(fmt_declaration, ip1, node->cidr, style, color);
+ if (node->bit[0]) {
+ swap_endian_and_apply_cidr(ip2,
+ rcu_dereference_raw(node->bit[0])->bits, bits,
+ node->cidr);
+ printk(fmt_connection, ip1, node->cidr, ip2,
+ rcu_dereference_raw(node->bit[0])->cidr);
+ print_node(rcu_dereference_raw(node->bit[0]), bits);
+ }
+ if (node->bit[1]) {
+ swap_endian_and_apply_cidr(ip2,
+ rcu_dereference_raw(node->bit[1])->bits,
+ bits, node->cidr);
+ printk(fmt_connection, ip1, node->cidr, ip2,
+ rcu_dereference_raw(node->bit[1])->cidr);
+ print_node(rcu_dereference_raw(node->bit[1]), bits);
+ }
+}
+
+static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
+{
+ printk(KERN_DEBUG "digraph trie {\n");
+ print_node(rcu_dereference_raw(top), bits);
+ printk(KERN_DEBUG "}\n");
+}
+
+enum {
+ NUM_PEERS = 2000,
+ NUM_RAND_ROUTES = 400,
+ NUM_MUTATED_ROUTES = 100,
+ NUM_QUERIES = NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30
+};
+
+struct horrible_allowedips {
+ struct hlist_head head;
+};
+
+struct horrible_allowedips_node {
+ struct hlist_node table;
+ union nf_inet_addr ip;
+ union nf_inet_addr mask;
+ u8 ip_version;
+ void *value;
+};
+
+static __init void horrible_allowedips_init(struct horrible_allowedips *table)
+{
+ INIT_HLIST_HEAD(&table->head);
+}
+
+static __init void horrible_allowedips_free(struct horrible_allowedips *table)
+{
+ struct horrible_allowedips_node *node;
+ struct hlist_node *h;
+
+ hlist_for_each_entry_safe(node, h, &table->head, table) {
+ hlist_del(&node->table);
+ kfree(node);
+ }
+}
+
+static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
+{
+ union nf_inet_addr mask;
+
+ memset(&mask, 0x00, 128 / 8);
+ memset(&mask, 0xff, cidr / 8);
+ if (cidr % 32)
+ mask.all[cidr / 32] = (__force u32)htonl(
+ (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
+ return mask;
+}
+
+static __init inline u8 horrible_mask_to_cidr(union nf_inet_addr subnet)
+{
+ return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) +
+ hweight32(subnet.all[2]) + hweight32(subnet.all[3]);
+}
+
+static __init inline void
+horrible_mask_self(struct horrible_allowedips_node *node)
+{
+ if (node->ip_version == 4) {
+ node->ip.ip &= node->mask.ip;
+ } else if (node->ip_version == 6) {
+ node->ip.ip6[0] &= node->mask.ip6[0];
+ node->ip.ip6[1] &= node->mask.ip6[1];
+ node->ip.ip6[2] &= node->mask.ip6[2];
+ node->ip.ip6[3] &= node->mask.ip6[3];
+ }
+}
+
+static __init inline bool
+horrible_match_v4(const struct horrible_allowedips_node *node,
+ struct in_addr *ip)
+{
+ return (ip->s_addr & node->mask.ip) == node->ip.ip;
+}
+
+static __init inline bool
+horrible_match_v6(const struct horrible_allowedips_node *node,
+ struct in6_addr *ip)
+{
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
+ node->ip.ip6[0] &&
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
+ node->ip.ip6[1] &&
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
+ node->ip.ip6[2] &&
+ (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
+}
+
+static __init void
+horrible_insert_ordered(struct horrible_allowedips *table,
+ struct horrible_allowedips_node *node)
+{
+ struct horrible_allowedips_node *other = NULL, *where = NULL;
+ u8 my_cidr = horrible_mask_to_cidr(node->mask);
+
+ hlist_for_each_entry(other, &table->head, table) {
+ if (!memcmp(&other->mask, &node->mask,
+ sizeof(union nf_inet_addr)) &&
+ !memcmp(&other->ip, &node->ip,
+ sizeof(union nf_inet_addr)) &&
+ other->ip_version == node->ip_version) {
+ other->value = node->value;
+ kfree(node);
+ return;
+ }
+ where = other;
+ if (horrible_mask_to_cidr(other->mask) <= my_cidr)
+ break;
+ }
+ if (!other && !where)
+ hlist_add_head(&node->table, &table->head);
+ else if (!other)
+ hlist_add_behind(&node->table, &where->table);
+ else
+ hlist_add_before(&node->table, &where->table);
+}
+
+static __init int
+horrible_allowedips_insert_v4(struct horrible_allowedips *table,
+ struct in_addr *ip, u8 cidr, void *value)
+{
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
+ GFP_KERNEL);
+
+ if (unlikely(!node))
+ return -ENOMEM;
+ node->ip.in = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 4;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+
+static __init int
+horrible_allowedips_insert_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip, u8 cidr, void *value)
+{
+ struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
+ GFP_KERNEL);
+
+ if (unlikely(!node))
+ return -ENOMEM;
+ node->ip.in6 = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 6;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+
+static __init void *
+horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
+ struct in_addr *ip)
+{
+ struct horrible_allowedips_node *node;
+ void *ret = NULL;
+
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 4)
+ continue;
+ if (horrible_match_v4(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ }
+ return ret;
+}
+
+static __init void *
+horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip)
+{
+ struct horrible_allowedips_node *node;
+ void *ret = NULL;
+
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 6)
+ continue;
+ if (horrible_match_v6(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ }
+ return ret;
+}
+
+static __init bool randomized_test(void)
+{
+ unsigned int i, j, k, mutate_amount, cidr;
+ u8 ip[16], mutate_mask[16], mutated[16];
+ struct wg_peer **peers, *peer;
+ struct horrible_allowedips h;
+ DEFINE_MUTEX(mutex);
+ struct allowedips t;
+ bool ret = false;
+
+ mutex_init(&mutex);
+
+ wg_allowedips_init(&t);
+ horrible_allowedips_init(&h);
+
+ peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL);
+ if (unlikely(!peers)) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ for (i = 0; i < NUM_PEERS; ++i) {
+ peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL);
+ if (unlikely(!peers[i])) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ kref_init(&peers[i]->refcount);
+ }
+
+ mutex_lock(&mutex);
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 4);
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
+ peer, &mutex) < 0) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
+ cidr, peer) < 0) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 4);
+ prandom_bytes(mutate_mask, 4);
+ mutate_amount = prandom_u32_max(32);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ prandom_u32_max(256));
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (wg_allowedips_insert_v4(&t,
+ (struct in_addr *)mutated,
+ cidr, peer, &mutex) < 0) {
+ pr_err("allowedips random malloc: FAIL\n");
+ goto free_locked;
+ }
+ if (horrible_allowedips_insert_v4(&h,
+ (struct in_addr *)mutated, cidr, peer)) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ }
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 16);
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
+ peer, &mutex) < 0) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
+ cidr, peer) < 0) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 16);
+ prandom_bytes(mutate_mask, 16);
+ mutate_amount = prandom_u32_max(128);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ prandom_u32_max(256));
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (wg_allowedips_insert_v6(&t,
+ (struct in6_addr *)mutated,
+ cidr, peer, &mutex) < 0) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ if (horrible_allowedips_insert_v6(
+ &h, (struct in6_addr *)mutated, cidr,
+ peer)) {
+ pr_err("allowedips random self-test malloc: FAIL\n");
+ goto free_locked;
+ }
+ }
+ }
+
+ mutex_unlock(&mutex);
+
+ if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 4);
+ if (lookup(t.root4, 32, ip) !=
+ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
+ pr_err("allowedips random self-test: FAIL\n");
+ goto free;
+ }
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 16);
+ if (lookup(t.root6, 128, ip) !=
+ horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
+ pr_err("allowedips random self-test: FAIL\n");
+ goto free;
+ }
+ }
+ ret = true;
+
+free:
+ mutex_lock(&mutex);
+free_locked:
+ wg_allowedips_free(&t, &mutex);
+ mutex_unlock(&mutex);
+ horrible_allowedips_free(&h);
+ if (peers) {
+ for (i = 0; i < NUM_PEERS; ++i)
+ kfree(peers[i]);
+ }
+ kfree(peers);
+ return ret;
+}
+
+static __init inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
+{
+ static struct in_addr ip;
+ u8 *split = (u8 *)&ip;
+
+ split[0] = a;
+ split[1] = b;
+ split[2] = c;
+ split[3] = d;
+ return &ip;
+}
+
+static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
+{
+ static struct in6_addr ip;
+ __be32 *split = (__be32 *)&ip;
+
+ split[0] = cpu_to_be32(a);
+ split[1] = cpu_to_be32(b);
+ split[2] = cpu_to_be32(c);
+ split[3] = cpu_to_be32(d);
+ return &ip;
+}
+
+static __init struct wg_peer *init_peer(void)
+{
+ struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL);
+
+ if (!peer)
+ return NULL;
+ kref_init(&peer->refcount);
+ INIT_LIST_HEAD(&peer->allowedips_list);
+ return peer;
+}
+
+#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
+ cidr, mem, &mutex)
+
+#define maybe_fail() do { \
+ ++i; \
+ if (!_s) { \
+ pr_info("allowedips self-test %zu: FAIL\n", i); \
+ success = false; \
+ } \
+ } while (0)
+
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) == (mem); \
+ maybe_fail(); \
+ } while (0)
+
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) != (mem); \
+ maybe_fail(); \
+ } while (0)
+
+#define test_boolean(cond) do { \
+ bool _s = (cond); \
+ maybe_fail(); \
+ } while (0)
+
+bool __init wg_allowedips_selftest(void)
+{
+ bool found_a = false, found_b = false, found_c = false, found_d = false,
+ found_e = false, found_other = false;
+ struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
+ *d = init_peer(), *e = init_peer(), *f = init_peer(),
+ *g = init_peer(), *h = init_peer();
+ struct allowedips_node *iter_node;
+ bool success = false;
+ struct allowedips t;
+ DEFINE_MUTEX(mutex);
+ struct in6_addr ip;
+ size_t i = 0, count = 0;
+ __be64 part;
+
+ mutex_init(&mutex);
+ mutex_lock(&mutex);
+ wg_allowedips_init(&t);
+
+ if (!a || !b || !c || !d || !e || !f || !g || !h) {
+ pr_err("allowedips self-test malloc: FAIL\n");
+ goto free;
+ }
+
+ insert(4, a, 192, 168, 4, 0, 24);
+ insert(4, b, 192, 168, 4, 4, 32);
+ insert(4, c, 192, 168, 0, 0, 16);
+ insert(4, d, 192, 95, 5, 64, 27);
+ /* replaces previous entry, and maskself is required */
+ insert(4, c, 192, 95, 5, 65, 27);
+ insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
+ insert(4, e, 0, 0, 0, 0, 0);
+ insert(6, e, 0, 0, 0, 0, 0);
+ /* replaces previous entry */
+ insert(6, f, 0, 0, 0, 0, 0);
+ insert(6, g, 0x24046800, 0, 0, 0, 32);
+ /* maskself is required */
+ insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
+ insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
+ insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ insert(4, g, 64, 15, 112, 0, 20);
+ /* maskself is required */
+ insert(4, h, 64, 15, 123, 211, 25);
+ insert(4, a, 10, 0, 0, 0, 25);
+ insert(4, b, 10, 0, 0, 128, 25);
+ insert(4, a, 10, 1, 0, 0, 30);
+ insert(4, b, 10, 1, 0, 4, 30);
+ insert(4, c, 10, 1, 0, 8, 29);
+ insert(4, d, 10, 1, 0, 16, 29);
+
+ if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+ }
+
+ success = true;
+
+ test(4, a, 192, 168, 4, 20);
+ test(4, a, 192, 168, 4, 0);
+ test(4, b, 192, 168, 4, 4);
+ test(4, c, 192, 168, 200, 182);
+ test(4, c, 192, 95, 5, 68);
+ test(4, e, 192, 95, 5, 96);
+ test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
+ test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
+ test(6, f, 0x26075300, 0x60006b01, 0, 0);
+ test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
+ test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0, 0);
+ test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
+ test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
+ test(4, g, 64, 15, 116, 26);
+ test(4, g, 64, 15, 127, 3);
+ test(4, g, 64, 15, 123, 1);
+ test(4, h, 64, 15, 123, 128);
+ test(4, h, 64, 15, 123, 129);
+ test(4, a, 10, 0, 0, 52);
+ test(4, b, 10, 0, 0, 220);
+ test(4, a, 10, 1, 0, 2);
+ test(4, b, 10, 1, 0, 6);
+ test(4, c, 10, 1, 0, 10);
+ test(4, d, 10, 1, 0, 20);
+
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 64, 0, 0, 0, 32);
+ insert(4, a, 128, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 32);
+ insert(4, a, 255, 0, 0, 0, 32);
+ wg_allowedips_remove_by_peer(&t, a, &mutex);
+ test_negative(4, a, 1, 0, 0, 0);
+ test_negative(4, a, 64, 0, 0, 0);
+ test_negative(4, a, 128, 0, 0, 0);
+ test_negative(4, a, 192, 0, 0, 0);
+ test_negative(4, a, 255, 0, 0, 0);
+
+ wg_allowedips_free(&t, &mutex);
+ wg_allowedips_init(&t);
+ insert(4, a, 192, 168, 0, 0, 16);
+ insert(4, a, 192, 168, 0, 0, 24);
+ wg_allowedips_remove_by_peer(&t, a, &mutex);
+ test_negative(4, a, 192, 168, 0, 1);
+
+ /* These will hit the WARN_ON(len >= 128) in free_node if something
+ * goes wrong.
+ */
+ for (i = 0; i < 128; ++i) {
+ part = cpu_to_be64(~(1LLU << (i % 64)));
+ memset(&ip, 0xff, 16);
+ memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
+ wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
+ }
+
+ wg_allowedips_free(&t, &mutex);
+
+ wg_allowedips_init(&t);
+ insert(4, a, 192, 95, 5, 93, 27);
+ insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(4, a, 10, 1, 0, 20, 29);
+ insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
+ insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
+ list_for_each_entry(iter_node, &a->allowedips_list, peer_list) {
+ u8 cidr, ip[16] __aligned(__alignof(u64));
+ int family = wg_allowedips_read_node(iter_node, ip, &cidr);
+
+ count++;
+
+ if (cidr == 27 && family == AF_INET &&
+ !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
+ found_a = true;
+ else if (cidr == 128 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
+ sizeof(struct in6_addr)))
+ found_b = true;
+ else if (cidr == 29 && family == AF_INET &&
+ !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
+ found_c = true;
+ else if (cidr == 83 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
+ sizeof(struct in6_addr)))
+ found_d = true;
+ else if (cidr == 21 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075000, 0, 0, 0),
+ sizeof(struct in6_addr)))
+ found_e = true;
+ else
+ found_other = true;
+ }
+ test_boolean(count == 5);
+ test_boolean(found_a);
+ test_boolean(found_b);
+ test_boolean(found_c);
+ test_boolean(found_d);
+ test_boolean(found_e);
+ test_boolean(!found_other);
+
+ if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success)
+ success = randomized_test();
+
+ if (success)
+ pr_info("allowedips self-tests: pass\n");
+
+free:
+ wg_allowedips_free(&t, &mutex);
+ kfree(a);
+ kfree(b);
+ kfree(c);
+ kfree(d);
+ kfree(e);
+ kfree(f);
+ kfree(g);
+ kfree(h);
+ mutex_unlock(&mutex);
+
+ return success;
+}
+
+#undef test_negative
+#undef test
+#undef remove
+#undef insert
+#undef init_peer
+
+#endif