0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
/* Prefix the messages printed by pr_debug() and friends in this file. */
#define pr_fmt(fmt) "MPTCP: " fmt
0024
0025 #include <linux/kernel.h>
0026 #include <linux/module.h>
0027 #include <linux/memblock.h>
0028 #include <linux/ip.h>
0029 #include <linux/tcp.h>
0030 #include <net/sock.h>
0031 #include <net/inet_common.h>
0032 #include <net/protocol.h>
0033 #include <net/mptcp.h>
0034 #include "protocol.h"
0035
/* Maximum number of entries (request socks plus msk sockets) one bucket may
 * hold: __token_bucket_busy() refuses a new token once a bucket's chain_len
 * reaches this value.
 */
#define TOKEN_MAX_CHAIN_LEN 4

/* One slot of the token hash table. Writers hold @lock; the lookups in
 * mptcp_token_exists(), mptcp_token_get_sock() and mptcp_token_iter_next()
 * walk the chains under RCU only.
 */
struct token_bucket {
	spinlock_t lock;			/* serializes updates of both chains and chain_len */
	int chain_len;				/* number of entries in req_chain + msk_chain */
	struct hlist_nulls_head req_chain;	/* mptcp_subflow_request_sock, linked via token_node */
	struct hlist_nulls_head msk_chain;	/* mptcp_sock, linked as sock nulls nodes */
};

/* The table has token_mask + 1 buckets; allocated by mptcp_token_init(). */
static struct token_bucket *token_hash __read_mostly;
static unsigned int token_mask __read_mostly;
0047
0048 static struct token_bucket *token_bucket(u32 token)
0049 {
0050 return &token_hash[token & token_mask];
0051 }
0052
0053
0054 static struct mptcp_subflow_request_sock *
0055 __token_lookup_req(struct token_bucket *t, u32 token)
0056 {
0057 struct mptcp_subflow_request_sock *req;
0058 struct hlist_nulls_node *pos;
0059
0060 hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node)
0061 if (req->token == token)
0062 return req;
0063 return NULL;
0064 }
0065
0066
0067 static struct mptcp_sock *
0068 __token_lookup_msk(struct token_bucket *t, u32 token)
0069 {
0070 struct hlist_nulls_node *pos;
0071 struct sock *sk;
0072
0073 sk_nulls_for_each_rcu(sk, pos, &t->msk_chain)
0074 if (mptcp_sk(sk)->token == token)
0075 return mptcp_sk(sk);
0076 return NULL;
0077 }
0078
0079 static bool __token_bucket_busy(struct token_bucket *t, u32 token)
0080 {
0081 return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN ||
0082 __token_lookup_req(t, token) || __token_lookup_msk(t, token);
0083 }
0084
0085 static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
0086 {
0087
0088
0089
0090
0091
0092
0093 get_random_bytes(key, sizeof(u64));
0094 mptcp_crypto_key_sha(*key, token, idsn);
0095 }
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108 int mptcp_token_new_request(struct request_sock *req)
0109 {
0110 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
0111 struct token_bucket *bucket;
0112 u32 token;
0113
0114 mptcp_crypto_key_sha(subflow_req->local_key,
0115 &subflow_req->token,
0116 &subflow_req->idsn);
0117 pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n",
0118 req, subflow_req->local_key, subflow_req->token,
0119 subflow_req->idsn);
0120
0121 token = subflow_req->token;
0122 bucket = token_bucket(token);
0123 spin_lock_bh(&bucket->lock);
0124 if (__token_bucket_busy(bucket, token)) {
0125 spin_unlock_bh(&bucket->lock);
0126 return -EBUSY;
0127 }
0128
0129 hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain);
0130 bucket->chain_len++;
0131 spin_unlock_bh(&bucket->lock);
0132 return 0;
0133 }
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150
0151 int mptcp_token_new_connect(struct sock *sk)
0152 {
0153 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
0154 struct mptcp_sock *msk = mptcp_sk(subflow->conn);
0155 int retries = MPTCP_TOKEN_MAX_RETRIES;
0156 struct token_bucket *bucket;
0157
0158 again:
0159 mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token,
0160 &subflow->idsn);
0161
0162 bucket = token_bucket(subflow->token);
0163 spin_lock_bh(&bucket->lock);
0164 if (__token_bucket_busy(bucket, subflow->token)) {
0165 spin_unlock_bh(&bucket->lock);
0166 if (!--retries)
0167 return -EBUSY;
0168 goto again;
0169 }
0170
0171 pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
0172 sk, subflow->local_key, subflow->token, subflow->idsn);
0173
0174 WRITE_ONCE(msk->token, subflow->token);
0175 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
0176 bucket->chain_len++;
0177 spin_unlock_bh(&bucket->lock);
0178 return 0;
0179 }
0180
0181
0182
0183
0184
0185
0186
0187
0188
0189 void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
0190 struct mptcp_sock *msk)
0191 {
0192 struct mptcp_subflow_request_sock *pos;
0193 struct token_bucket *bucket;
0194
0195 bucket = token_bucket(req->token);
0196 spin_lock_bh(&bucket->lock);
0197
0198
0199 pos = __token_lookup_req(bucket, req->token);
0200 if (!WARN_ON_ONCE(pos != req))
0201 hlist_nulls_del_init_rcu(&req->token_node);
0202 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
0203 spin_unlock_bh(&bucket->lock);
0204 }
0205
0206 bool mptcp_token_exists(u32 token)
0207 {
0208 struct hlist_nulls_node *pos;
0209 struct token_bucket *bucket;
0210 struct mptcp_sock *msk;
0211 struct sock *sk;
0212
0213 rcu_read_lock();
0214 bucket = token_bucket(token);
0215
0216 again:
0217 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
0218 msk = mptcp_sk(sk);
0219 if (READ_ONCE(msk->token) == token)
0220 goto found;
0221 }
0222 if (get_nulls_value(pos) != (token & token_mask))
0223 goto again;
0224
0225 rcu_read_unlock();
0226 return false;
0227 found:
0228 rcu_read_unlock();
0229 return true;
0230 }
0231
0232
0233
0234
0235
0236
0237
0238
0239
0240
0241
/**
 * mptcp_token_get_sock - look up an msk by token within a network namespace
 * @net: namespace the msk must belong to
 * @token: token to look for
 *
 * Lockless walk of the token's msk_chain under rcu_read_lock(). On a match a
 * reference is taken with refcount_inc_not_zero(), then token and namespace
 * are checked again while the reference is held.
 *
 * Return: the matching msk with a reference taken for the caller, or NULL when
 * no entry matches or the matching socket's refcount had already dropped to 0.
 */
struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
{
	struct hlist_nulls_node *pos;
	struct token_bucket *bucket;
	struct mptcp_sock *msk;
	struct sock *sk;

	rcu_read_lock();
	bucket = token_bucket(token);

again:
	sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
		msk = mptcp_sk(sk);
		if (READ_ONCE(msk->token) != token ||
		    !net_eq(sock_net(sk), net))
			continue;

		/* refcount already zero: treat as not found, do not retry */
		if (!refcount_inc_not_zero(&sk->sk_refcnt))
			goto not_found;

		/* The entry may have changed between the unlocked check above
		 * and taking the reference (e.g. unhashed/reused - assumption,
		 * see rculist_nulls rules): re-validate, and on mismatch drop
		 * the reference and restart the walk.
		 */
		if (READ_ONCE(msk->token) != token ||
		    !net_eq(sock_net(sk), net)) {
			sock_put(sk);
			goto again;
		}
		goto found;
	}
	/* Ended on another bucket's nulls marker: the walk strayed out of
	 * this chain, so restart it.
	 */
	if (get_nulls_value(pos) != (token & token_mask))
		goto again;

not_found:
	msk = NULL;

found:
	rcu_read_unlock();
	return msk;
}
EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
0280
0281
0282
0283
0284
0285
0286
0287
0288
0289
0290
0291
0292
/**
 * mptcp_token_iter_next - resumable walk over the hashed msk sockets of a netns
 * @net: namespace to filter on
 * @s_slot: in/out cursor: bucket index to resume from
 * @s_num: in/out cursor: 1-based position, within bucket *@s_slot's
 *         msk_chain, of the last socket returned; entries at or before this
 *         position are skipped
 *
 * Scans token_hash from bucket *@s_slot, taking rcu_read_lock() per bucket,
 * and returns the next msk of @net past the cursor with a reference taken.
 * Positions count every chain entry, including sockets of other namespaces.
 * *@s_num is reset to 0 whenever the scan advances to the next bucket, and
 * both cursors are written back so the next call resumes after the returned
 * socket. Sockets whose refcount already reached zero are skipped.
 *
 * Return: the next msk with a reference held by the caller, or NULL once all
 * buckets were scanned (*@s_slot is then token_mask + 1).
 */
struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
					 long *s_num)
{
	struct mptcp_sock *ret = NULL;
	struct hlist_nulls_node *pos;
	int slot, num = 0;

	for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) {
		struct token_bucket *bucket = &token_hash[slot];
		struct sock *sk;

		num = 0;

		/* skip empty buckets without entering an RCU section */
		if (hlist_nulls_empty(&bucket->msk_chain))
			continue;

		rcu_read_lock();
		sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
			/* position counts all entries, before any filtering */
			++num;
			if (!net_eq(sock_net(sk), net))
				continue;

			/* already returned by an earlier call */
			if (num <= *s_num)
				continue;

			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				continue;

			/* re-check the namespace now that a reference is held */
			if (!net_eq(sock_net(sk), net)) {
				sock_put(sk);
				continue;
			}

			ret = mptcp_sk(sk);
			rcu_read_unlock();
			goto out;
		}
		rcu_read_unlock();
	}

out:
	*s_slot = slot;
	*s_num = num;
	return ret;
}
EXPORT_SYMBOL_GPL(mptcp_token_iter_next);
0339
0340
0341
0342
0343
0344
0345
0346 void mptcp_token_destroy_request(struct request_sock *req)
0347 {
0348 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
0349 struct mptcp_subflow_request_sock *pos;
0350 struct token_bucket *bucket;
0351
0352 if (hlist_nulls_unhashed(&subflow_req->token_node))
0353 return;
0354
0355 bucket = token_bucket(subflow_req->token);
0356 spin_lock_bh(&bucket->lock);
0357 pos = __token_lookup_req(bucket, subflow_req->token);
0358 if (!WARN_ON_ONCE(pos != subflow_req)) {
0359 hlist_nulls_del_init_rcu(&pos->token_node);
0360 bucket->chain_len--;
0361 }
0362 spin_unlock_bh(&bucket->lock);
0363 }
0364
0365
0366
0367
0368
0369
0370
0371 void mptcp_token_destroy(struct mptcp_sock *msk)
0372 {
0373 struct token_bucket *bucket;
0374 struct mptcp_sock *pos;
0375
0376 if (sk_unhashed((struct sock *)msk))
0377 return;
0378
0379 bucket = token_bucket(msk->token);
0380 spin_lock_bh(&bucket->lock);
0381 pos = __token_lookup_msk(bucket, msk->token);
0382 if (!WARN_ON_ONCE(pos != msk)) {
0383 __sk_nulls_del_node_init_rcu((struct sock *)pos);
0384 bucket->chain_len--;
0385 }
0386 spin_unlock_bh(&bucket->lock);
0387 WRITE_ONCE(msk->token, 0);
0388 }
0389
0390 void __init mptcp_token_init(void)
0391 {
0392 int i;
0393
0394 token_hash = alloc_large_system_hash("MPTCP token",
0395 sizeof(struct token_bucket),
0396 0,
0397 20,
0398 HASH_ZERO,
0399 NULL,
0400 &token_mask,
0401 0,
0402 64 * 1024);
0403 for (i = 0; i < token_mask + 1; ++i) {
0404 INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i);
0405 INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i);
0406 spin_lock_init(&token_hash[i].lock);
0407 }
0408 }
0409
/* Export the token helpers only when the MPTCP KUnit test is built as a
 * module - presumably so that test module can call them directly; confirm
 * against the test sources.
 */
#if IS_MODULE(CONFIG_MPTCP_KUNIT_TEST)
EXPORT_SYMBOL_GPL(mptcp_token_new_request);
EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
EXPORT_SYMBOL_GPL(mptcp_token_accept);
EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
EXPORT_SYMBOL_GPL(mptcp_token_destroy);
#endif