// SPDX-License-Identifier: GPL-2.0-only
/*
 * Longest prefix match (LPM) trie implementation for eBPF
 * (BPF_MAP_TYPE_LPM_TRIE).
 */

#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/err.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <net/ipv6.h>
#include <uapi/linux/btf.h>
#include <linux/btf_ids.h>

/* Intermediate node: inserted internally where two entries diverge;
 * carries no value of its own and is never returned by lookups.
 */
#define LPM_TREE_NODE_FLAG_IM BIT(0)

struct lpm_trie_node;

struct lpm_trie_node {
	struct rcu_head rcu;
	struct lpm_trie_node __rcu *child[2];
	u32 prefixlen;
	u32 flags;
	u8 data[];	/* prefix bytes, followed by the value for non-IM nodes */
};

struct lpm_trie {
	struct bpf_map map;
	struct lpm_trie_node __rcu *root;
	size_t n_entries;
	size_t max_prefixlen;
	size_t data_size;
	spinlock_t lock;
};
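
/* A trie node's key is a (prefixlen, data) pair. Real entries are nodes
 * without LPM_TREE_NODE_FLAG_IM set; intermediate nodes are created where
 * two entries diverge and hold no value. The child[0]/child[1] pointers
 * are selected by the key bit that immediately follows the node's prefix,
 * so walking down from the root extends the matched prefix one node at a
 * time.
 *
 * Lookups walk the trie locklessly under RCU; updates and deletions
 * serialize on trie->lock and publish nodes with rcu_assign_pointer()
 * and free them with kfree_rcu().
 */

/* extract_bit() returns bit number @index of @data, counting from the
 * most significant bit of data[0], i.e. index 0 is the high-order bit
 * of the first byte, matching network byte order prefixes.
 */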
static inline int extract_bit(const u8 *data, size_t index)
{
	return !!(data[index / 8] & (1 << (7 - (index % 8))));
}
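
/* longest_prefix_match() - determine the longest prefix
 * @trie:	The trie to get internal sizes from
 * @node:	The node to operate on
 * @key:	The key to compare to @node
 *
 * Determine the longest prefix of @node that matches the bits in @key,
 * capped at min(node->prefixlen, key->prefixlen). The comparison works
 * in the widest chunks the data size allows (8/4/2/1 bytes).
 */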
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	u32 limit = min(node->prefixlen, key->prefixlen);
	u32 prefixlen = 0, i = 0;

	BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
	BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));

#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)

	/* data_size >= 8 bytes: compare the first u64 in one go. The
	 * BUILD_BUG_ONs above only guarantee 4-byte alignment, so this
	 * fast path additionally requires efficient unaligned 64-bit
	 * loads.
	 */
	if (trie->data_size >= 8) {
		u64 diff = be64_to_cpu(*(__be64 *)node->data ^
				       *(__be64 *)key->data);

		prefixlen = 64 - fls64(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i = 8;
	}
#endif

	while (trie->data_size >= i + 4) {
		u32 diff = be32_to_cpu(*(__be32 *)&node->data[i] ^
				       *(__be32 *)&key->data[i]);

		prefixlen += 32 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 4;
	}

	if (trie->data_size >= i + 2) {
		u16 diff = be16_to_cpu(*(__be16 *)&node->data[i] ^
				       *(__be16 *)&key->data[i]);

		prefixlen += 16 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 2;
	}

	if (trie->data_size >= i + 1) {
		prefixlen += 8 - fls(node->data[i] ^ key->data[i]);

		if (prefixlen >= limit)
			return limit;
	}

	return prefixlen;
}
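
/* Called from syscall or from eBPF program */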
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node ... */

	for (node = rcu_dereference_check(trie->root, rcu_read_lock_bh_held());
	     node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that matches @key.
		 * If it's the maximum possible prefix for this trie, we have
		 * an exact match and can return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If the number of bits matched is smaller than the prefix
		 * length of @node, bail out and return the node we have seen
		 * last in the traversal (ie, the parent).
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* If the node match is fully satisfied, let the tree walk
		 * continue by looking for a more specific match in the
		 * subtree selected by the next bit of the key.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference_check(node->child[next_bit],
					     rcu_read_lock_bh_held());
	}

	if (!found)
		return NULL;

	return found->data + trie->data_size;
}

static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
						 const void *value)
{
	struct lpm_trie_node *node;
	size_t size = sizeof(struct lpm_trie_node) + trie->data_size;

	/* Intermediate nodes are allocated with value == NULL and hence
	 * carry no value storage.
	 */
	if (value)
		size += trie->map.value_size;

	node = bpf_map_kmalloc_node(&trie->map, size, GFP_NOWAIT | __GFP_NOWARN,
				    trie->map.numa_node);
	if (!node)
		return NULL;

	node->flags = 0;

	if (value)
		memcpy(node->data + trie->data_size, value,
		       trie->map.value_size);

	return node;
}
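
/* Called from syscall or from eBPF program */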
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */

	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To avoid extra
	 * conditionals around the root, walk with a pointer to the slot
	 * (the root pointer or a child pointer) rather than the node
	 * itself.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor. Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, assign the intermediate node to the determined slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
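
/* Called from syscall or from eBPF program -- remove the entry that
 * exactly matches @key, trimming intermediate nodes that are no longer
 * needed.
 */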
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse. We will need to know the node we
	 * wish to delete and the slot that points to it. We may also need
	 * to know the node's parent and the slot that contains the parent.
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	if (!node || node->prefixlen != key->prefixlen ||
	    node->prefixlen != matchlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate and we are done.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node we are about to delete is an
	 * intermediate node, and the deleted node doesn't have any
	 * children, we can delete the intermediate parent as well and
	 * promote its other child up the tree. This keeps intermediate
	 * nodes with exactly two children and avoids unnecessary
	 * intermediate nodes in the tree.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node being removed has at most one child. If that child
	 * pointer is not NULL, move the child into the node's slot;
	 * otherwise just clear the slot.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}

#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE |	\
				 BPF_F_ACCESS_MASK)

static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
	struct lpm_trie *trie;

	if (!bpf_capable())
		return ERR_PTR(-EPERM);

	/* check sanity of attributes */
	if (attr->max_entries == 0 ||
	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
	    attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
	    !bpf_map_flags_access_ok(attr->map_flags) ||
	    attr->key_size < LPM_KEY_SIZE_MIN ||
	    attr->key_size > LPM_KEY_SIZE_MAX ||
	    attr->value_size < LPM_VAL_SIZE_MIN ||
	    attr->value_size > LPM_VAL_SIZE_MAX)
		return ERR_PTR(-EINVAL);

	trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN | __GFP_ACCOUNT);
	if (!trie)
		return ERR_PTR(-ENOMEM);

	/* copy mandatory map attributes */
	bpf_map_init_from_attr(&trie->map, attr);
	trie->data_size = attr->key_size -
			  offsetof(struct bpf_lpm_trie_key, data);
	trie->max_prefixlen = trie->data_size * 8;

	spin_lock_init(&trie->lock);

	return &trie->map;
}

static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	/* Always start at the root and walk down to a node that has no
	 * children. This path will be deleted and we can continue
	 * iterating with the same pattern until the trie is empty. The
	 * map is being freed, so nobody can access it anymore and no RCU
	 * protection is needed.
	 */
	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot, 1);
			if (!node)
				goto out;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

out:
	kfree(trie);
}

static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
	struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
	struct lpm_trie_node **node_stack = NULL;
	int err = 0, stack_ptr = -1;
	unsigned int next_bit;
	size_t matchlen;

	/* The get_next_key walks the trie in postorder, so that more
	 * specific keys are returned before less specific ones: for each
	 * node, the left subtree is visited first, then the right subtree,
	 * then the node itself. Intermediate nodes are skipped since they
	 * carry no entry.
	 */
	search_root = rcu_dereference(trie->root);
	if (!search_root)
		return -ENOENT;

	/* For an invalid key, find the leftmost node in the trie */
	if (!key || key->prefixlen > trie->max_prefixlen)
		goto find_leftmost;

	node_stack = kmalloc_array(trie->max_prefixlen,
				   sizeof(struct lpm_trie_node *),
				   GFP_ATOMIC | __GFP_NOWARN);
	if (!node_stack)
		return -ENOMEM;

	/* Try to find the exact node for the given key */
	for (node = search_root; node;) {
		node_stack[++stack_ptr] = node;
		matchlen = longest_prefix_match(trie, node, key);
		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM))
		goto find_leftmost;

	/* The node with the exactly-matching key has been found, now find
	 * the first node in postorder after the matched node.
	 */
	node = node_stack[stack_ptr];
	while (stack_ptr > 0) {
		parent = node_stack[stack_ptr - 1];
		if (rcu_dereference(parent->child[0]) == node) {
			search_root = rcu_dereference(parent->child[1]);
			if (search_root)
				goto find_leftmost;
		}
		if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
			next_node = parent;
			goto do_copy;
		}

		node = parent;
		stack_ptr--;
	}

	/* did not find anything */
	err = -ENOENT;
	goto free_stack;

find_leftmost:
	/* Find the leftmost non-intermediate node. Intermediate nodes
	 * always have two children, so following child[0] is safe.
	 */
	for (node = search_root; node;) {
		if (node->flags & LPM_TREE_NODE_FLAG_IM) {
			node = rcu_dereference(node->child[0]);
		} else {
			next_node = node;
			node = rcu_dereference(node->child[0]);
			if (!node)
				node = rcu_dereference(next_node->child[1]);
		}
	}
do_copy:
	next_key->prefixlen = next_node->prefixlen;
	memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
	       next_node->data, trie->data_size);
free_stack:
	kfree(node_stack);
	return err;
}

static int trie_check_btf(const struct bpf_map *map,
			  const struct btf *btf,
			  const struct btf_type *key_type,
			  const struct btf_type *value_type)
{
	/* Keys must have struct bpf_lpm_trie_key embedded. */
	return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
	       -EINVAL : 0;
}

BTF_ID_LIST_SINGLE(trie_map_btf_ids, struct, lpm_trie)
const struct bpf_map_ops trie_map_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
	.map_lookup_batch = generic_map_lookup_batch,
	.map_update_batch = generic_map_update_batch,
	.map_delete_batch = generic_map_delete_batch,
	.map_check_btf = trie_check_btf,
	.map_btf_id = &trie_map_btf_ids[0],
};