#include "ratelimiter.h"
#include <linux/siphash.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <net/ip.h>

static struct kmem_cache *entry_cache;
static hsiphash_key_t key;
static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
static DEFINE_MUTEX(init_lock);
static u64 init_refcnt; /* Only touched under init_lock, hence not atomic. */
static atomic_t total_entries = ATOMIC_INIT(0);
static unsigned int max_entries, table_size;
static void wg_ratelimiter_gc_entries(struct work_struct *);
static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
static struct hlist_head *table_v6;
#endif

struct ratelimiter_entry {
	u64 last_time_ns, tokens, ip;
	void *net;
	spinlock_t lock;
	struct hlist_node hash;
	struct rcu_head rcu;
};

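/* Tokens are measured in nanoseconds: each admitted packet costs PACKET_COST
 * and tokens accrue with elapsed time up to TOKEN_MAX, giving a sustained
 * PACKETS_PER_SECOND with bursts of up to PACKETS_BURSTABLE packets.
 */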
enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

static void entry_free(struct rcu_head *rcu)
{
	kmem_cache_free(entry_cache,
			container_of(rcu, struct ratelimiter_entry, rcu));
	atomic_dec(&total_entries);
}

static void entry_uninit(struct ratelimiter_entry *entry)
{
	hlist_del_rcu(&entry->hash);
	call_rcu(&entry->rcu, entry_free);
}

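/* Calling this function with a NULL work pointer removes every entry,
 * regardless of age; wg_ratelimiter_uninit() relies on this.
 */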
static void wg_ratelimiter_gc_entries(struct work_struct *work)
{
	const u64 now = ktime_get_coarse_boottime_ns();
	struct ratelimiter_entry *entry;
	struct hlist_node *temp;
	unsigned int i;

	for (i = 0; i < table_size; ++i) {
		spin_lock(&table_lock);
		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#if IS_ENABLED(CONFIG_IPV6)
		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#endif
		spin_unlock(&table_lock);
		if (likely(work))
			cond_resched();
	}
	if (likely(work))
		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}

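/* Returns true when the packet is within budget. A minimal caller sketch
 * (illustrative only; skb and net come from whatever receive path is using
 * the limiter):
 *
 *	if (!wg_ratelimiter_allow(skb, dev_net(skb->dev)))
 *		goto drop;
 */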
bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
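	/* Only the low 32 bits of the net pointer go into the hash, which
	 * keeps the keyed-hash input at two u32 words for IPv4 and three
	 * for IPv6 below.
	 */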
	const u32 net_word = (unsigned long)net;
	struct ratelimiter_entry *entry;
	struct hlist_head *bucket;
	u64 ip;

	if (skb->protocol == htons(ETH_P_IP)) {
		ip = (u64 __force)ip_hdr(skb)->saddr;
		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
				   (table_size - 1)];
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (skb->protocol == htons(ETH_P_IPV6)) {
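		/* Only the first 64 bits of the source address are copied,
		 * so every host in a /64 shares one bucket.
		 */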
		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
				   (table_size - 1)];
	}
#endif
	else
		return false;
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, hash) {
		if (entry->net == net && entry->ip == ip) {
			u64 now, tokens;
			bool ret;
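			/* Token bucket update: tokens accrue one per elapsed
			 * nanosecond, capped at TOKEN_MAX; an admitted packet
			 * spends PACKET_COST of them.
			 */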
			spin_lock(&entry->lock);
			now = ktime_get_coarse_boottime_ns();
			tokens = min_t(u64, TOKEN_MAX,
				       entry->tokens + now -
					       entry->last_time_ns);
			entry->last_time_ns = now;
			ret = tokens >= PACKET_COST;
			entry->tokens = ret ? tokens - PACKET_COST : tokens;
			spin_unlock(&entry->lock);
			rcu_read_unlock();
			return ret;
		}
	}
	rcu_read_unlock();

	if (atomic_inc_return(&total_entries) > max_entries)
		goto err_oom;

	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
	if (unlikely(!entry))
		goto err_oom;

	entry->net = net;
	entry->ip = ip;
	INIT_HLIST_NODE(&entry->hash);
	spin_lock_init(&entry->lock);
	entry->last_time_ns = ktime_get_coarse_boottime_ns();
	entry->tokens = TOKEN_MAX - PACKET_COST;
	spin_lock(&table_lock);
	hlist_add_head_rcu(&entry->hash, bucket);
	spin_unlock(&table_lock);
	return true;

err_oom:
	atomic_dec(&total_entries);
	return false;
}

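/* Init is reference-counted: the hash tables, slab cache, and GC work are
 * created on the first call and torn down only after the matching number of
 * wg_ratelimiter_uninit() calls.
 */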
int wg_ratelimiter_init(void)
{
	mutex_lock(&init_lock);
	if (++init_refcnt != 1)
		goto out;

	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
	if (!entry_cache)
		goto err;

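	/* Size the table from available RAM: more than 1 GiB gets a fixed
	 * 8192 buckets; otherwise use roughly one bucket per 16 KiB of RAM
	 * (divided by the size of an hlist_head), rounded up to a power of
	 * two with a floor of 16.
	 */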
	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
		max_t(unsigned long, 16, roundup_pow_of_two(
			(totalram_pages() << PAGE_SHIFT) /
			(1U << 14) / sizeof(struct hlist_head)));
	max_entries = table_size * 8;

	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
	if (unlikely(!table_v4))
		goto err_kmemcache;

#if IS_ENABLED(CONFIG_IPV6)
	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
	if (unlikely(!table_v6)) {
		kvfree(table_v4);
		goto err_kmemcache;
	}
#endif

	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
	get_random_bytes(&key, sizeof(key));
out:
	mutex_unlock(&init_lock);
	return 0;

err_kmemcache:
	kmem_cache_destroy(entry_cache);
err:
	--init_refcnt;
	mutex_unlock(&init_lock);
	return -ENOMEM;
}

void wg_ratelimiter_uninit(void)
{
	mutex_lock(&init_lock);
	if (!init_refcnt || --init_refcnt)
		goto out;

	cancel_delayed_work_sync(&gc_work);
	wg_ratelimiter_gc_entries(NULL);
	rcu_barrier();
	kvfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
	kvfree(table_v6);
#endif
	kmem_cache_destroy(entry_cache);
out:
	mutex_unlock(&init_lock);
}

#include "selftest/ratelimiter.c"