Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /* Multipath TCP token management
0003  * Copyright (c) 2017 - 2019, Intel Corporation.
0004  *
0005  * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org,
0006  *       authored by:
0007  *
0008  *       Sébastien Barré <sebastien.barre@uclouvain.be>
0009  *       Christoph Paasch <christoph.paasch@uclouvain.be>
0010  *       Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi>
0011  *       Gregory Detal <gregory.detal@uclouvain.be>
0012  *       Fabien Duchêne <fabien.duchene@uclouvain.be>
0013  *       Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de>
0014  *       Lavkesh Lahngir <lavkesh51@gmail.com>
0015  *       Andreas Ripke <ripke@neclab.eu>
0016  *       Vlad Dogaru <vlad.dogaru@intel.com>
0017  *       Octavian Purdila <octavian.purdila@intel.com>
0018  *       John Ronan <jronan@tssg.org>
0019  *       Catalin Nicutar <catalin.nicutar@gmail.com>
0020  *       Brandon Heller <brandonh@stanford.edu>
0021  */
0022 
0023 #define pr_fmt(fmt) "MPTCP: " fmt
0024 
0025 #include <linux/kernel.h>
0026 #include <linux/module.h>
0027 #include <linux/memblock.h>
0028 #include <linux/ip.h>
0029 #include <linux/tcp.h>
0030 #include <net/sock.h>
0031 #include <net/inet_common.h>
0032 #include <net/protocol.h>
0033 #include <net/mptcp.h>
0034 #include "protocol.h"
0035 
/* Maximum number of entries (request socks plus msks) accounted to one
 * bucket; __token_bucket_busy() refuses new tokens once chain_len reaches it.
 */
0036 #define TOKEN_MAX_CHAIN_LEN 4
0037 
/* One bucket of the token hash table.
 * @lock:      held (with BH disabled) by every insertion into or removal
 *             from either chain, and by every update of @chain_len
 * @chain_len: number of entries accounted to this bucket
 * @req_chain: request socks, linked through their token_node
 * @msk_chain: mptcp socks, linked through their sk nulls node
 *
 * The nulls marker of both chains is the bucket index (set in
 * mptcp_token_init()).  Lockless RCU readers compare the marker they end
 * on with the expected index and restart the walk on mismatch.
 */
0038 struct token_bucket {
0039     spinlock_t      lock;
0040     int         chain_len;
0041     struct hlist_nulls_head req_chain;
0042     struct hlist_nulls_head msk_chain;
0043 };
0044 
/* Token table: token_mask + 1 buckets, indexed by (token & token_mask).
 * Allocated and initialized once by mptcp_token_init().
 */
0045 static struct token_bucket *token_hash __read_mostly;
0046 static unsigned int token_mask __read_mostly;
0047 
0048 static struct token_bucket *token_bucket(u32 token)
0049 {
0050     return &token_hash[token & token_mask];
0051 }
0052 
0053 /* called with bucket lock held */
0054 static struct mptcp_subflow_request_sock *
0055 __token_lookup_req(struct token_bucket *t, u32 token)
0056 {
0057     struct mptcp_subflow_request_sock *req;
0058     struct hlist_nulls_node *pos;
0059 
0060     hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node)
0061         if (req->token == token)
0062             return req;
0063     return NULL;
0064 }
0065 
0066 /* called with bucket lock held */
0067 static struct mptcp_sock *
0068 __token_lookup_msk(struct token_bucket *t, u32 token)
0069 {
0070     struct hlist_nulls_node *pos;
0071     struct sock *sk;
0072 
0073     sk_nulls_for_each_rcu(sk, pos, &t->msk_chain)
0074         if (mptcp_sk(sk)->token == token)
0075             return mptcp_sk(sk);
0076     return NULL;
0077 }
0078 
0079 static bool __token_bucket_busy(struct token_bucket *t, u32 token)
0080 {
0081     return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN ||
0082            __token_lookup_req(t, token) || __token_lookup_msk(t, token);
0083 }
0084 
0085 static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
0086 {
0087     /* we might consider a faster version that computes the key as a
0088      * hash of some information available in the MPTCP socket. Use
0089      * random data at the moment, as it's probably the safest option
0090      * in case multiple sockets are opened in different namespaces at
0091      * the same time.
0092      */
0093     get_random_bytes(key, sizeof(u64));
0094     mptcp_crypto_key_sha(*key, token, idsn);
0095 }
0096 
0097 /**
0098  * mptcp_token_new_request - create new key/idsn/token for subflow_request
0099  * @req: the request socket
0100  *
0101  * This function is called when a new mptcp connection is coming in.
0102  *
0103  * It creates a unique token to identify the new mptcp connection,
0104  * a secret local key and the initial data sequence number (idsn).
0105  *
0106  * Returns 0 on success.
0107  */
0108 int mptcp_token_new_request(struct request_sock *req)
0109 {
0110     struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
0111     struct token_bucket *bucket;
0112     u32 token;
0113 
0114     mptcp_crypto_key_sha(subflow_req->local_key,
0115                  &subflow_req->token,
0116                  &subflow_req->idsn);
0117     pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n",
0118          req, subflow_req->local_key, subflow_req->token,
0119          subflow_req->idsn);
0120 
0121     token = subflow_req->token;
0122     bucket = token_bucket(token);
0123     spin_lock_bh(&bucket->lock);
0124     if (__token_bucket_busy(bucket, token)) {
0125         spin_unlock_bh(&bucket->lock);
0126         return -EBUSY;
0127     }
0128 
0129     hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain);
0130     bucket->chain_len++;
0131     spin_unlock_bh(&bucket->lock);
0132     return 0;
0133 }
0134 
0135 /**
0136  * mptcp_token_new_connect - create new key/idsn/token for subflow
0137  * @sk: the socket that will initiate a connection
0138  *
0139  * This function is called when a new outgoing mptcp connection is
0140  * initiated.
0141  *
0142  * It creates a unique token to identify the new mptcp connection,
0143  * a secret local key and the initial data sequence number (idsn).
0144  *
0145  * On success, the mptcp connection can be found again using
0146  * the computed token at a later time, this is needed to process
0147  * join requests.
0148  *
0149  * returns 0 on success.
0150  */
0151 int mptcp_token_new_connect(struct sock *sk)
0152 {
0153     struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
0154     struct mptcp_sock *msk = mptcp_sk(subflow->conn);
0155     int retries = MPTCP_TOKEN_MAX_RETRIES;
0156     struct token_bucket *bucket;
0157 
0158 again:
0159     mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token,
0160                  &subflow->idsn);
0161 
0162     bucket = token_bucket(subflow->token);
0163     spin_lock_bh(&bucket->lock);
0164     if (__token_bucket_busy(bucket, subflow->token)) {
0165         spin_unlock_bh(&bucket->lock);
0166         if (!--retries)
0167             return -EBUSY;
0168         goto again;
0169     }
0170 
0171     pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
0172          sk, subflow->local_key, subflow->token, subflow->idsn);
0173 
0174     WRITE_ONCE(msk->token, subflow->token);
0175     __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
0176     bucket->chain_len++;
0177     spin_unlock_bh(&bucket->lock);
0178     return 0;
0179 }
0180 
0181 /**
0182  * mptcp_token_accept - replace a req sk with full sock in token hash
0183  * @req: the request socket to be removed
0184  * @msk: the just cloned socket linked to the new connection
0185  *
0186  * Called when a SYN packet creates a new logical connection, i.e.
0187  * is not a join request.
0188  */
0189 void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
0190             struct mptcp_sock *msk)
0191 {
0192     struct mptcp_subflow_request_sock *pos;
0193     struct token_bucket *bucket;
0194 
0195     bucket = token_bucket(req->token);
0196     spin_lock_bh(&bucket->lock);
0197 
0198     /* pedantic lookup check for the moved token */
0199     pos = __token_lookup_req(bucket, req->token);
0200     if (!WARN_ON_ONCE(pos != req))
0201         hlist_nulls_del_init_rcu(&req->token_node);
0202     __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
0203     spin_unlock_bh(&bucket->lock);
0204 }
0205 
0206 bool mptcp_token_exists(u32 token)
0207 {
0208     struct hlist_nulls_node *pos;
0209     struct token_bucket *bucket;
0210     struct mptcp_sock *msk;
0211     struct sock *sk;
0212 
0213     rcu_read_lock();
0214     bucket = token_bucket(token);
0215 
0216 again:
0217     sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
0218         msk = mptcp_sk(sk);
0219         if (READ_ONCE(msk->token) == token)
0220             goto found;
0221     }
0222     if (get_nulls_value(pos) != (token & token_mask))
0223         goto again;
0224 
0225     rcu_read_unlock();
0226     return false;
0227 found:
0228     rcu_read_unlock();
0229     return true;
0230 }
0231 
0232 /**
0233  * mptcp_token_get_sock - retrieve mptcp connection sock using its token
0234  * @net: restrict to this namespace
0235  * @token: token of the mptcp connection to retrieve
0236  *
0237  * This function returns the mptcp connection structure with the given token.
0238  * A reference count on the mptcp socket returned is taken; the caller is
 * responsible for dropping it (this function itself uses sock_put() to drop
 * the reference on its recheck path).
 *
 * The lookup is lockless: it walks the bucket's msk chain under RCU.  A
 * candidate is returned only after its refcount was raised AND its token
 * and namespace were checked again.  If either changed in the meantime,
 * the reference is dropped and the walk restarts.
0239  *
0240  * Return: the msk, or NULL if no connection with the given token value
 * exists in @net, or if the matching socket's refcount had already dropped
 * to zero.
0241  */
0242 struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
0243 {
0244     struct hlist_nulls_node *pos;
0245     struct token_bucket *bucket;
0246     struct mptcp_sock *msk;
0247     struct sock *sk;
0248 
0249     rcu_read_lock();
0250     bucket = token_bucket(token);
0251 
0252 again:
0253     sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
0254         msk = mptcp_sk(sk);
0255         if (READ_ONCE(msk->token) != token ||
0256             !net_eq(sock_net(sk), net))
0257             continue;
0258 
        /* refcount already zero (socket presumably being freed): report
         * no match instead of scanning further
         */
0259         if (!refcount_inc_not_zero(&sk->sk_refcnt))
0260             goto not_found;
0261 
        /* re-validate now that a reference is held: token or netns may
         * have changed since the unreferenced test above
         */
0262         if (READ_ONCE(msk->token) != token ||
0263             !net_eq(sock_net(sk), net)) {
0264             sock_put(sk);
0265             goto again;
0266         }
0267         goto found;
0268     }
    /* a nulls marker other than this bucket's index means the walk was
     * diverted onto another chain: restart the scan
     */
0269     if (get_nulls_value(pos) != (token & token_mask))
0270         goto again;
0271 
0272 not_found:
0273     msk = NULL;
0274 
0275 found:
0276     rcu_read_unlock();
0277     return msk;
0278 }
0279 EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
0280 
0281 /**
0282  * mptcp_token_iter_next - iterate over the token container from given pos
0283  * @net: namespace to be iterated
0284  * @s_slot: start slot number (bucket index); updated on return
0285  * @s_num: number of msk chain entries of slot *@s_slot already visited;
 *         updated on return
0286  *
0287  * This function returns the first mptcp connection structure found inside the
0288  * token container starting from the specified position, or NULL.
0289  *
0290  * On successful iteration, the iterator is moved to the next position and a
0291  * reference to the returned socket is acquired, which the caller must drop.
 *
 * Entries of every namespace (and entries whose refcount already dropped to
 * zero) are counted in *@s_num, even though only sockets of @net are
 * returned.  *@s_num applies only to the first slot visited; it is reset to
 * 0 when moving to the next slot.
 *
 * NOTE(review): unlike mptcp_token_get_sock(), this walk does no
 * nulls-marker restart.  Presumably that is acceptable for a best-effort
 * dump; confirm.
0292  */
0293 struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
0294                      long *s_num)
0295 {
0296     struct mptcp_sock *ret = NULL;
0297     struct hlist_nulls_node *pos;
0298     int slot, num = 0;
0299 
0300     for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) {
0301         struct token_bucket *bucket = &token_hash[slot];
0302         struct sock *sk;
0303 
0304         num = 0;
0305 
        /* cheap skip of empty buckets, done before rcu_read_lock() */
0306         if (hlist_nulls_empty(&bucket->msk_chain))
0307             continue;
0308 
0309         rcu_read_lock();
0310         sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
0311             ++num;
0312             if (!net_eq(sock_net(sk), net))
0313                 continue;
0314 
            /* already visited by a previous call */
0315             if (num <= *s_num)
0316                 continue;
0317 
0318             if (!refcount_inc_not_zero(&sk->sk_refcnt))
0319                 continue;
0320 
            /* re-check the namespace now that a reference is held */
0321             if (!net_eq(sock_net(sk), net)) {
0322                 sock_put(sk);
0323                 continue;
0324             }
0325 
0326             ret = mptcp_sk(sk);
0327             rcu_read_unlock();
0328             goto out;
0329         }
0330         rcu_read_unlock();
0331     }
0332 
0333 out:
    /* save the position of the returned entry (or of the end of the
     * table), so the next call resumes right after it
     */
0334     *s_slot = slot;
0335     *s_num = num;
0336     return ret;
0337 }
0338 EXPORT_SYMBOL_GPL(mptcp_token_iter_next);
0339 
0340 /**
0341  * mptcp_token_destroy_request - remove mptcp connection/token
0342  * @req: mptcp request socket dropping the token
0343  *
0344  * Remove the token associated to @req.
0345  */
0346 void mptcp_token_destroy_request(struct request_sock *req)
0347 {
0348     struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
0349     struct mptcp_subflow_request_sock *pos;
0350     struct token_bucket *bucket;
0351 
0352     if (hlist_nulls_unhashed(&subflow_req->token_node))
0353         return;
0354 
0355     bucket = token_bucket(subflow_req->token);
0356     spin_lock_bh(&bucket->lock);
0357     pos = __token_lookup_req(bucket, subflow_req->token);
0358     if (!WARN_ON_ONCE(pos != subflow_req)) {
0359         hlist_nulls_del_init_rcu(&pos->token_node);
0360         bucket->chain_len--;
0361     }
0362     spin_unlock_bh(&bucket->lock);
0363 }
0364 
0365 /**
0366  * mptcp_token_destroy - remove mptcp connection/token
0367  * @msk: mptcp connection dropping the token
0368  *
0369  * Remove the token associated to @msk
0370  */
0371 void mptcp_token_destroy(struct mptcp_sock *msk)
0372 {
0373     struct token_bucket *bucket;
0374     struct mptcp_sock *pos;
0375 
0376     if (sk_unhashed((struct sock *)msk))
0377         return;
0378 
0379     bucket = token_bucket(msk->token);
0380     spin_lock_bh(&bucket->lock);
0381     pos = __token_lookup_msk(bucket, msk->token);
0382     if (!WARN_ON_ONCE(pos != msk)) {
0383         __sk_nulls_del_node_init_rcu((struct sock *)pos);
0384         bucket->chain_len--;
0385     }
0386     spin_unlock_bh(&bucket->lock);
0387     WRITE_ONCE(msk->token, 0);
0388 }
0389 
0390 void __init mptcp_token_init(void)
0391 {
0392     int i;
0393 
0394     token_hash = alloc_large_system_hash("MPTCP token",
0395                          sizeof(struct token_bucket),
0396                          0,
0397                          20,/* one slot per 1MB of memory */
0398                          HASH_ZERO,
0399                          NULL,
0400                          &token_mask,
0401                          0,
0402                          64 * 1024);
0403     for (i = 0; i < token_mask + 1; ++i) {
0404         INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i);
0405         INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i);
0406         spin_lock_init(&token_hash[i].lock);
0407     }
0408 }
0409 
/* Export the token management entry points only when
 * CONFIG_MPTCP_KUNIT_TEST is built as a module.  Presumably this lets the
 * modular KUnit token tests call them directly; TODO confirm.
 */
0410 #if IS_MODULE(CONFIG_MPTCP_KUNIT_TEST)
0411 EXPORT_SYMBOL_GPL(mptcp_token_new_request);
0412 EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
0413 EXPORT_SYMBOL_GPL(mptcp_token_accept);
0414 EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
0415 EXPORT_SYMBOL_GPL(mptcp_token_destroy);
0416 #endif