// SPDX-License-Identifier: GPL-2.0-only
/*
 * count the number of connections matching an arbitrary key.
 *
 * split from xt_connlimit.c
 */
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/jhash.h>
#include <linux/slab.h>
#include <linux/list.h>
#include <linux/rbtree.h>
#include <linux/module.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/spinlock.h>
#include <linux/netfilter/nf_conntrack_tcp.h>
#include <linux/netfilter/x_tables.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_count.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_zones.h>

#define CONNCOUNT_SLOTS		256U

#define CONNCOUNT_GC_MAX_NODES	8
#define MAX_KEYLEN		5

/* we will save the tuples of all connections we care about */
struct nf_conncount_tuple {
	struct list_head		node;
	struct nf_conntrack_tuple	tuple;
	struct nf_conntrack_zone	zone;
	int				cpu;
	u32				jiffies32;
};

struct nf_conncount_rb {
	struct rb_node node;
	struct nf_conncount_list list;
	u32 key[MAX_KEYLEN];
	struct rcu_head rcu_head;
};

static spinlock_t nf_conncount_locks[CONNCOUNT_SLOTS] __cacheline_aligned_in_smp;

struct nf_conncount_data {
	unsigned int keylen;
	struct rb_root root[CONNCOUNT_SLOTS];
	struct net *net;
	struct work_struct gc_work;
	unsigned long pending_trees[BITS_TO_LONGS(CONNCOUNT_SLOTS)];
	unsigned int gc_tree;
};
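
/*
 * Overview (summary comment added for clarity): lookups hash the
 * user-supplied key into one of CONNCOUNT_SLOTS buckets; each bucket
 * holds an rbtree ordered by memcmp() of the key, and every tree node
 * (nf_conncount_rb) carries the list of conntrack tuples currently
 * counted against that key:
 *
 *	data->root[jhash2(key) % CONNCOUNT_SLOTS]
 *	  `- rbtree of nf_conncount_rb { key, list }
 *	       `- list of nf_conncount_tuple { tuple, zone, cpu, jiffies32 }
 */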

static u_int32_t conncount_rnd __read_mostly;
static struct kmem_cache *conncount_rb_cachep __read_mostly;
static struct kmem_cache *conncount_conn_cachep __read_mostly;

static inline bool already_closed(const struct nf_conn *conn)
{
	if (nf_ct_protonum(conn) == IPPROTO_TCP)
		return conn->proto.tcp.state == TCP_CONNTRACK_TIME_WAIT ||
		       conn->proto.tcp.state == TCP_CONNTRACK_CLOSE;
	else
		return false;
}

static int key_diff(const u32 *a, const u32 *b, unsigned int klen)
{
	return memcmp(a, b, klen * sizeof(u32));
}

static void conn_free(struct nf_conncount_list *list,
		      struct nf_conncount_tuple *conn)
{
	lockdep_assert_held(&list->list_lock);

	list->count--;
	list_del(&conn->node);

	kmem_cache_free(conncount_conn_cachep, conn);
}

static const struct nf_conntrack_tuple_hash *
find_or_evict(struct net *net, struct nf_conncount_list *list,
	      struct nf_conncount_tuple *conn)
{
	const struct nf_conntrack_tuple_hash *found;
	unsigned long a, b;
	int cpu = raw_smp_processor_id();
	u32 age;

	found = nf_conntrack_find_get(net, &conn->zone, &conn->tuple);
	if (found)
		return found;
	b = conn->jiffies32;
	a = (u32)jiffies;

	/* conn might have been added just before by another cpu and
	 * might still be unconfirmed.  In this case, nf_conntrack_find()
	 * returns no result.  Thus only evict if this cpu added the
	 * stale entry or if the entry is older than two jiffies.
	 */
	age = a - b;
	if (conn->cpu == cpu || age >= 2) {
		conn_free(list, conn);
		return ERR_PTR(-ENOENT);
	}

	return ERR_PTR(-EAGAIN);
}

static int __nf_conncount_add(struct net *net,
			      struct nf_conncount_list *list,
			      const struct nf_conntrack_tuple *tuple,
			      const struct nf_conntrack_zone *zone)
{
	const struct nf_conntrack_tuple_hash *found;
	struct nf_conncount_tuple *conn, *conn_n;
	struct nf_conn *found_ct;
	unsigned int collect = 0;

	if (time_is_after_eq_jiffies((unsigned long)list->last_gc))
		goto add_new_node;

	/* check the saved connections */
	list_for_each_entry_safe(conn, conn_n, &list->head, node) {
		if (collect > CONNCOUNT_GC_MAX_NODES)
			break;

		found = find_or_evict(net, list, conn);
		if (IS_ERR(found)) {
			/* Not found, but might be about to be confirmed */
			if (PTR_ERR(found) == -EAGAIN) {
				if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
				    nf_ct_zone_id(&conn->zone, conn->zone.dir) ==
				    nf_ct_zone_id(zone, zone->dir))
					return 0; /* already exists */
			} else {
				collect++;
			}
			continue;
		}

		found_ct = nf_ct_tuplehash_to_ctrack(found);

		if (nf_ct_tuple_equal(&conn->tuple, tuple) &&
		    nf_ct_zone_equal(found_ct, zone, zone->dir)) {
			/*
			 * We should not see tuples twice unless someone hooks
			 * this into a table without "-p tcp --syn".
			 *
			 * Attempt to avoid a re-add in this case.
			 */
			nf_ct_put(found_ct);
			return 0;
		} else if (already_closed(found_ct)) {
			/*
			 * we do not care about connections which are
			 * closed already -> ditch it
			 */
			nf_ct_put(found_ct);
			conn_free(list, conn);
			collect++;
			continue;
		}

		nf_ct_put(found_ct);
	}

add_new_node:
	if (WARN_ON_ONCE(list->count > INT_MAX))
		return -EOVERFLOW;

	conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
	if (conn == NULL)
		return -ENOMEM;

	conn->tuple = *tuple;
	conn->zone = *zone;
	conn->cpu = raw_smp_processor_id();
	conn->jiffies32 = (u32)jiffies;
	list_add_tail(&conn->node, &list->head);
	list->count++;
	list->last_gc = (u32)jiffies;
	return 0;
}

int nf_conncount_add(struct net *net,
		     struct nf_conncount_list *list,
		     const struct nf_conntrack_tuple *tuple,
		     const struct nf_conntrack_zone *zone)
{
	int ret;

	/* check the saved connections */
	spin_lock_bh(&list->list_lock);
	ret = __nf_conncount_add(net, list, tuple, zone);
	spin_unlock_bh(&list->list_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_add);

void nf_conncount_list_init(struct nf_conncount_list *list)
{
	spin_lock_init(&list->list_lock);
	INIT_LIST_HEAD(&list->head);
	list->count = 0;
	list->last_gc = (u32)jiffies;
}
EXPORT_SYMBOL_GPL(nf_conncount_list_init);

/* Return true if the list is empty */
bool nf_conncount_gc_list(struct net *net,
			  struct nf_conncount_list *list)
{
	const struct nf_conntrack_tuple_hash *found;
	struct nf_conncount_tuple *conn, *conn_n;
	struct nf_conn *found_ct;
	unsigned int collected = 0;
	bool ret = false;

	/* don't bother if we just did GC */
	if (time_is_after_eq_jiffies((unsigned long)READ_ONCE(list->last_gc)))
		return false;

	/* don't bother if other cpu is already doing GC */
	if (!spin_trylock(&list->list_lock))
		return false;

	list_for_each_entry_safe(conn, conn_n, &list->head, node) {
		found = find_or_evict(net, list, conn);
		if (IS_ERR(found)) {
			if (PTR_ERR(found) == -ENOENT)
				collected++;
			continue;
		}

		found_ct = nf_ct_tuplehash_to_ctrack(found);
		if (already_closed(found_ct)) {
			/*
			 * we do not care about connections which are
			 * closed already -> ditch it
			 */
			nf_ct_put(found_ct);
			conn_free(list, conn);
			collected++;
			continue;
		}

		nf_ct_put(found_ct);
		if (collected > CONNCOUNT_GC_MAX_NODES)
			break;
	}

	if (!list->count)
		ret = true;
	list->last_gc = (u32)jiffies;
	spin_unlock(&list->list_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(nf_conncount_gc_list);

static void __tree_nodes_free(struct rcu_head *h)
{
	struct nf_conncount_rb *rbconn;

	rbconn = container_of(h, struct nf_conncount_rb, rcu_head);
	kmem_cache_free(conncount_rb_cachep, rbconn);
}

/* caller must hold tree nf_conncount_locks[] lock */
static void tree_nodes_free(struct rb_root *root,
			    struct nf_conncount_rb *gc_nodes[],
			    unsigned int gc_count)
{
	struct nf_conncount_rb *rbconn;

	while (gc_count) {
		rbconn = gc_nodes[--gc_count];
		spin_lock(&rbconn->list.list_lock);
		if (!rbconn->list.count) {
			rb_erase(&rbconn->node, root);
			call_rcu(&rbconn->rcu_head, __tree_nodes_free);
		}
		spin_unlock(&rbconn->list.list_lock);
	}
}

static void schedule_gc_worker(struct nf_conncount_data *data, int tree)
{
	set_bit(tree, data->pending_trees);
	schedule_work(&data->gc_work);
}

static unsigned int
insert_tree(struct net *net,
	    struct nf_conncount_data *data,
	    struct rb_root *root,
	    unsigned int hash,
	    const u32 *key,
	    const struct nf_conntrack_tuple *tuple,
	    const struct nf_conntrack_zone *zone)
{
	struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES];
	struct rb_node **rbnode, *parent;
	struct nf_conncount_rb *rbconn;
	struct nf_conncount_tuple *conn;
	unsigned int count = 0, gc_count = 0;
	u8 keylen = data->keylen;
	bool do_gc = true;

	spin_lock_bh(&nf_conncount_locks[hash]);
restart:
	parent = NULL;
	rbnode = &(root->rb_node);
	while (*rbnode) {
		int diff;
		rbconn = rb_entry(*rbnode, struct nf_conncount_rb, node);

		parent = *rbnode;
		diff = key_diff(key, rbconn->key, keylen);
		if (diff < 0) {
			rbnode = &((*rbnode)->rb_left);
		} else if (diff > 0) {
			rbnode = &((*rbnode)->rb_right);
		} else {
			int ret;

			ret = nf_conncount_add(net, &rbconn->list, tuple, zone);
			if (ret)
				count = 0; /* hotdrop */
			else
				count = rbconn->list.count;
			tree_nodes_free(root, gc_nodes, gc_count);
			goto out_unlock;
		}

		if (gc_count >= ARRAY_SIZE(gc_nodes))
			continue;

		if (do_gc && nf_conncount_gc_list(net, &rbconn->list))
			gc_nodes[gc_count++] = rbconn;
	}

	if (gc_count) {
		tree_nodes_free(root, gc_nodes, gc_count);
		schedule_gc_worker(data, hash);
		gc_count = 0;
		do_gc = false;
		goto restart;
	}

	/* no match: allocate and insert a new node for this key */
	rbconn = kmem_cache_alloc(conncount_rb_cachep, GFP_ATOMIC);
	if (rbconn == NULL)
		goto out_unlock;

	conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC);
	if (conn == NULL) {
		kmem_cache_free(conncount_rb_cachep, rbconn);
		goto out_unlock;
	}

	conn->tuple = *tuple;
	conn->zone = *zone;
	/* fully initialize the entry: find_or_evict() reads cpu and
	 * jiffies32, so leaving them as slab garbage can cause bogus
	 * evictions of entries that are about to be confirmed.
	 */
	conn->cpu = raw_smp_processor_id();
	conn->jiffies32 = (u32)jiffies;
	memcpy(rbconn->key, key, sizeof(u32) * keylen);

	nf_conncount_list_init(&rbconn->list);
	list_add(&conn->node, &rbconn->list.head);
	count = 1;
	rbconn->list.count = count;

	rb_link_node_rcu(&rbconn->node, parent, rbnode);
	rb_insert_color(&rbconn->node, root);
out_unlock:
	spin_unlock_bh(&nf_conncount_locks[hash]);
	return count;
}

static unsigned int
count_tree(struct net *net,
	   struct nf_conncount_data *data,
	   const u32 *key,
	   const struct nf_conntrack_tuple *tuple,
	   const struct nf_conntrack_zone *zone)
{
	struct rb_root *root;
	struct rb_node *parent;
	struct nf_conncount_rb *rbconn;
	unsigned int hash;
	u8 keylen = data->keylen;

	hash = jhash2(key, data->keylen, conncount_rnd) % CONNCOUNT_SLOTS;
	root = &data->root[hash];

	parent = rcu_dereference_raw(root->rb_node);
	while (parent) {
		int diff;

		rbconn = rb_entry(parent, struct nf_conncount_rb, node);

		diff = key_diff(key, rbconn->key, keylen);
		if (diff < 0) {
			parent = rcu_dereference_raw(parent->rb_left);
		} else if (diff > 0) {
			parent = rcu_dereference_raw(parent->rb_right);
		} else {
			int ret;

			if (!tuple) {
				nf_conncount_gc_list(net, &rbconn->list);
				return rbconn->list.count;
			}

			spin_lock_bh(&rbconn->list.list_lock);
			/* Node might be about to be free'd.
			 * We need to defer to insert_tree() in this case.
			 */
			if (rbconn->list.count == 0) {
				spin_unlock_bh(&rbconn->list.list_lock);
				break;
			}

			/* same source network -> be counted! */
			ret = __nf_conncount_add(net, &rbconn->list, tuple, zone);
			spin_unlock_bh(&rbconn->list.list_lock);
			if (ret)
				return 0; /* hotdrop */
			else
				return rbconn->list.count;
		}
	}

	if (!tuple)
		return 0;

	return insert_tree(net, data, root, hash, key, tuple, zone);
}
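
/*
 * Garbage collection summary (comment added for clarity): stale list
 * entries are reaped opportunistically by nf_conncount_gc_list();
 * insert_tree() additionally collects up to CONNCOUNT_GC_MAX_NODES
 * empty tree nodes per walk, and when it hits that many it defers the
 * rest to this worker via schedule_gc_worker(), one slot at a time.
 */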
static void tree_gc_worker(struct work_struct *work)
{
	struct nf_conncount_data *data = container_of(work, struct nf_conncount_data, gc_work);
	struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES], *rbconn;
	struct rb_root *root;
	struct rb_node *node;
	unsigned int tree, next_tree, gc_count = 0;

	tree = data->gc_tree % CONNCOUNT_SLOTS;
	root = &data->root[tree];

	/* first pass: RCU walk, count empty lists without mutating the tree */
	local_bh_disable();
	rcu_read_lock();
	for (node = rb_first(root); node != NULL; node = rb_next(node)) {
		rbconn = rb_entry(node, struct nf_conncount_rb, node);
		if (nf_conncount_gc_list(data->net, &rbconn->list))
			gc_count++;
	}
	rcu_read_unlock();
	local_bh_enable();

	cond_resched();

	spin_lock_bh(&nf_conncount_locks[tree]);
	if (gc_count < ARRAY_SIZE(gc_nodes))
		goto next; /* not enough empty candidates to bother */

	/* second pass: under the slot lock, erase the empty nodes */
	gc_count = 0;
	node = rb_first(root);
	while (node != NULL) {
		rbconn = rb_entry(node, struct nf_conncount_rb, node);
		node = rb_next(node);

		if (rbconn->list.count > 0)
			continue;

		gc_nodes[gc_count++] = rbconn;
		if (gc_count >= ARRAY_SIZE(gc_nodes)) {
			tree_nodes_free(root, gc_nodes, gc_count);
			gc_count = 0;
		}
	}

	tree_nodes_free(root, gc_nodes, gc_count);
next:
	clear_bit(tree, data->pending_trees);

	next_tree = (tree + 1) % CONNCOUNT_SLOTS;
	next_tree = find_next_bit(data->pending_trees, CONNCOUNT_SLOTS, next_tree);

	if (next_tree < CONNCOUNT_SLOTS) {
		data->gc_tree = next_tree;
		schedule_work(work);
	}

	spin_unlock_bh(&nf_conncount_locks[tree]);
}

/* Count and return number of conntrack entries in 'net' with particular 'key'.
 * If 'tuple' is not null, insert it into the accounting data structure.
 * Call with RCU read lock.
 */
unsigned int nf_conncount_count(struct net *net,
				struct nf_conncount_data *data,
				const u32 *key,
				const struct nf_conntrack_tuple *tuple,
				const struct nf_conntrack_zone *zone)
{
	return count_tree(net, data, key, tuple, zone);
}
EXPORT_SYMBOL_GPL(nf_conncount_count);
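
/*
 * Usage sketch (illustrative, not from the original source; "limit" and
 * the surrounding policy are hypothetical).  A caller such as the
 * connlimit match allocates one nf_conncount_data per rule, then queries
 * it per packet under rcu_read_lock():
 *
 *	u32 key[MAX_KEYLEN] = {};	// e.g. masked source address
 *	struct nf_conncount_data *data;
 *	unsigned int count;
 *
 *	data = nf_conncount_init(net, NFPROTO_IPV4, sizeof(key));
 *	if (IS_ERR(data))
 *		return PTR_ERR(data);
 *
 *	// per packet: count connections for 'key', adding 'tuple'
 *	count = nf_conncount_count(net, data, key, tuple, zone);
 *	if (count > limit)
 *		...	// drop or reject the new connection
 *
 *	// on rule teardown:
 *	nf_conncount_destroy(net, NFPROTO_IPV4, data);
 */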

struct nf_conncount_data *nf_conncount_init(struct net *net, unsigned int family,
					    unsigned int keylen)
{
	struct nf_conncount_data *data;
	int ret, i;

	/* keylen is passed in bytes and must describe a whole number of
	 * u32 words, at most MAX_KEYLEN of them.
	 */
	if (keylen % sizeof(u32) ||
	    keylen / sizeof(u32) > MAX_KEYLEN ||
	    keylen == 0)
		return ERR_PTR(-EINVAL);

	net_get_random_once(&conncount_rnd, sizeof(conncount_rnd));

	data = kmalloc(sizeof(*data), GFP_KERNEL);
	if (!data)
		return ERR_PTR(-ENOMEM);

	ret = nf_ct_netns_get(net, family);
	if (ret < 0) {
		kfree(data);
		return ERR_PTR(ret);
	}

	for (i = 0; i < ARRAY_SIZE(data->root); ++i)
		data->root[i] = RB_ROOT;

	data->keylen = keylen / sizeof(u32);
	data->net = net;
	INIT_WORK(&data->gc_work, tree_gc_worker);

	return data;
}
EXPORT_SYMBOL_GPL(nf_conncount_init);

void nf_conncount_cache_free(struct nf_conncount_list *list)
{
	struct nf_conncount_tuple *conn, *conn_n;

	list_for_each_entry_safe(conn, conn_n, &list->head, node)
		kmem_cache_free(conncount_conn_cachep, conn);
}
EXPORT_SYMBOL_GPL(nf_conncount_cache_free);

static void destroy_tree(struct rb_root *r)
{
	struct nf_conncount_rb *rbconn;
	struct rb_node *node;

	while ((node = rb_first(r)) != NULL) {
		rbconn = rb_entry(node, struct nf_conncount_rb, node);

		rb_erase(node, r);

		nf_conncount_cache_free(&rbconn->list);

		kmem_cache_free(conncount_rb_cachep, rbconn);
	}
}

void nf_conncount_destroy(struct net *net, unsigned int family,
			  struct nf_conncount_data *data)
{
	unsigned int i;

	cancel_work_sync(&data->gc_work);
	nf_ct_netns_put(net, family);

	for (i = 0; i < ARRAY_SIZE(data->root); ++i)
		destroy_tree(&data->root[i]);

	kfree(data);
}
EXPORT_SYMBOL_GPL(nf_conncount_destroy);

static int __init nf_conncount_modinit(void)
{
	int i;

	for (i = 0; i < CONNCOUNT_SLOTS; ++i)
		spin_lock_init(&nf_conncount_locks[i]);

	conncount_conn_cachep = kmem_cache_create("nf_conncount_tuple",
					   sizeof(struct nf_conncount_tuple),
					   0, 0, NULL);
	if (!conncount_conn_cachep)
		return -ENOMEM;

	conncount_rb_cachep = kmem_cache_create("nf_conncount_rb",
					   sizeof(struct nf_conncount_rb),
					   0, 0, NULL);
	if (!conncount_rb_cachep) {
		kmem_cache_destroy(conncount_conn_cachep);
		return -ENOMEM;
	}

	return 0;
}

static void __exit nf_conncount_modexit(void)
{
	kmem_cache_destroy(conncount_conn_cachep);
	kmem_cache_destroy(conncount_rb_cachep);
}

module_init(nf_conncount_modinit);
module_exit(nf_conncount_modexit);
MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>");
MODULE_AUTHOR("Florian Westphal <fw@strlen.de>");
MODULE_DESCRIPTION("netfilter: count number of connections matching a key");
MODULE_LICENSE("GPL");