Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /*
0003  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
0004  */
0005 
0006 #include "ratelimiter.h"
0007 #include <linux/siphash.h>
0008 #include <linux/mm.h>
0009 #include <linux/slab.h>
0010 #include <net/ip.h>
0011 
0012 static struct kmem_cache *entry_cache;
0013 static hsiphash_key_t key;
0014 static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
0015 static DEFINE_MUTEX(init_lock);
0016 static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
0017 static atomic_t total_entries = ATOMIC_INIT(0);
0018 static unsigned int max_entries, table_size;
0019 static void wg_ratelimiter_gc_entries(struct work_struct *);
0020 static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
0021 static struct hlist_head *table_v4;
0022 #if IS_ENABLED(CONFIG_IPV6)
0023 static struct hlist_head *table_v6;
0024 #endif
0025 
0026 struct ratelimiter_entry {
0027     u64 last_time_ns, tokens, ip;
0028     void *net;
0029     spinlock_t lock;
0030     struct hlist_node hash;
0031     struct rcu_head rcu;
0032 };
0033 
0034 enum {
0035     PACKETS_PER_SECOND = 20,
0036     PACKETS_BURSTABLE = 5,
0037     PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
0038     TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
0039 };
0040 
0041 static void entry_free(struct rcu_head *rcu)
0042 {
0043     kmem_cache_free(entry_cache,
0044             container_of(rcu, struct ratelimiter_entry, rcu));
0045     atomic_dec(&total_entries);
0046 }
0047 
0048 static void entry_uninit(struct ratelimiter_entry *entry)
0049 {
0050     hlist_del_rcu(&entry->hash);
0051     call_rcu(&entry->rcu, entry_free);
0052 }
0053 
0054 /* Calling this function with a NULL work uninits all entries. */
0055 static void wg_ratelimiter_gc_entries(struct work_struct *work)
0056 {
0057     const u64 now = ktime_get_coarse_boottime_ns();
0058     struct ratelimiter_entry *entry;
0059     struct hlist_node *temp;
0060     unsigned int i;
0061 
0062     for (i = 0; i < table_size; ++i) {
0063         spin_lock(&table_lock);
0064         hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
0065             if (unlikely(!work) ||
0066                 now - entry->last_time_ns > NSEC_PER_SEC)
0067                 entry_uninit(entry);
0068         }
0069 #if IS_ENABLED(CONFIG_IPV6)
0070         hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
0071             if (unlikely(!work) ||
0072                 now - entry->last_time_ns > NSEC_PER_SEC)
0073                 entry_uninit(entry);
0074         }
0075 #endif
0076         spin_unlock(&table_lock);
0077         if (likely(work))
0078             cond_resched();
0079     }
0080     if (likely(work))
0081         queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
0082 }
0083 
0084 bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
0085 {
0086     /* We only take the bottom half of the net pointer, so that we can hash
0087      * 3 words in the end. This way, siphash's len param fits into the final
0088      * u32, and we don't incur an extra round.
0089      */
0090     const u32 net_word = (unsigned long)net;
0091     struct ratelimiter_entry *entry;
0092     struct hlist_head *bucket;
0093     u64 ip;
0094 
0095     if (skb->protocol == htons(ETH_P_IP)) {
0096         ip = (u64 __force)ip_hdr(skb)->saddr;
0097         bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
0098                    (table_size - 1)];
0099     }
0100 #if IS_ENABLED(CONFIG_IPV6)
0101     else if (skb->protocol == htons(ETH_P_IPV6)) {
0102         /* Only use 64 bits, so as to ratelimit the whole /64. */
0103         memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
0104         bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
0105                    (table_size - 1)];
0106     }
0107 #endif
0108     else
0109         return false;
0110     rcu_read_lock();
0111     hlist_for_each_entry_rcu(entry, bucket, hash) {
0112         if (entry->net == net && entry->ip == ip) {
0113             u64 now, tokens;
0114             bool ret;
0115             /* Quasi-inspired by nft_limit.c, but this is actually a
0116              * slightly different algorithm. Namely, we incorporate
0117              * the burst as part of the maximum tokens, rather than
0118              * as part of the rate.
0119              */
0120             spin_lock(&entry->lock);
0121             now = ktime_get_coarse_boottime_ns();
0122             tokens = min_t(u64, TOKEN_MAX,
0123                        entry->tokens + now -
0124                            entry->last_time_ns);
0125             entry->last_time_ns = now;
0126             ret = tokens >= PACKET_COST;
0127             entry->tokens = ret ? tokens - PACKET_COST : tokens;
0128             spin_unlock(&entry->lock);
0129             rcu_read_unlock();
0130             return ret;
0131         }
0132     }
0133     rcu_read_unlock();
0134 
0135     if (atomic_inc_return(&total_entries) > max_entries)
0136         goto err_oom;
0137 
0138     entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
0139     if (unlikely(!entry))
0140         goto err_oom;
0141 
0142     entry->net = net;
0143     entry->ip = ip;
0144     INIT_HLIST_NODE(&entry->hash);
0145     spin_lock_init(&entry->lock);
0146     entry->last_time_ns = ktime_get_coarse_boottime_ns();
0147     entry->tokens = TOKEN_MAX - PACKET_COST;
0148     spin_lock(&table_lock);
0149     hlist_add_head_rcu(&entry->hash, bucket);
0150     spin_unlock(&table_lock);
0151     return true;
0152 
0153 err_oom:
0154     atomic_dec(&total_entries);
0155     return false;
0156 }
0157 
0158 int wg_ratelimiter_init(void)
0159 {
0160     mutex_lock(&init_lock);
0161     if (++init_refcnt != 1)
0162         goto out;
0163 
0164     entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
0165     if (!entry_cache)
0166         goto err;
0167 
0168     /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
0169      * but what it shares in common is that it uses a massive hashtable. So,
0170      * we borrow their wisdom about good table sizes on different systems
0171      * dependent on RAM. This calculation here comes from there.
0172      */
0173     table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
0174         max_t(unsigned long, 16, roundup_pow_of_two(
0175             (totalram_pages() << PAGE_SHIFT) /
0176             (1U << 14) / sizeof(struct hlist_head)));
0177     max_entries = table_size * 8;
0178 
0179     table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
0180     if (unlikely(!table_v4))
0181         goto err_kmemcache;
0182 
0183 #if IS_ENABLED(CONFIG_IPV6)
0184     table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
0185     if (unlikely(!table_v6)) {
0186         kvfree(table_v4);
0187         goto err_kmemcache;
0188     }
0189 #endif
0190 
0191     queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
0192     get_random_bytes(&key, sizeof(key));
0193 out:
0194     mutex_unlock(&init_lock);
0195     return 0;
0196 
0197 err_kmemcache:
0198     kmem_cache_destroy(entry_cache);
0199 err:
0200     --init_refcnt;
0201     mutex_unlock(&init_lock);
0202     return -ENOMEM;
0203 }
0204 
0205 void wg_ratelimiter_uninit(void)
0206 {
0207     mutex_lock(&init_lock);
0208     if (!init_refcnt || --init_refcnt)
0209         goto out;
0210 
0211     cancel_delayed_work_sync(&gc_work);
0212     wg_ratelimiter_gc_entries(NULL);
0213     rcu_barrier();
0214     kvfree(table_v4);
0215 #if IS_ENABLED(CONFIG_IPV6)
0216     kvfree(table_v6);
0217 #endif
0218     kmem_cache_destroy(entry_cache);
0219 out:
0220     mutex_unlock(&init_lock);
0221 }
0222 
0223 #include "selftest/ratelimiter.c"