Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /* Multipath TCP
0003  *
0004  * Copyright (c) 2020, Red Hat, Inc.
0005  */
0006 
0007 #define pr_fmt(fmt) "MPTCP: " fmt
0008 
0009 #include <linux/inet.h>
0010 #include <linux/kernel.h>
0011 #include <net/tcp.h>
0012 #include <net/netns/generic.h>
0013 #include <net/mptcp.h>
0014 #include <net/genetlink.h>
0015 #include <uapi/linux/mptcp.h>
0016 
0017 #include "protocol.h"
0018 #include "mib.h"
0019 
0020 /* forward declaration */
0021 static struct genl_family mptcp_genl_family;
0022 
/* Key handed to net_generic() to fetch this file's struct pm_nl_pernet
 * for a given network namespace (see pm_nl_get_pernet()).
 */
0023 static int pm_nl_pernet_id;
0024 
/* One pending address announcement of an MPTCP socket.
 * Entries are allocated by mptcp_pm_alloc_anno_list() and linked on
 * msk->pm.anno_list.
 */
0025 struct mptcp_pm_add_entry {
0026     struct list_head    list;	/* node in msk->pm.anno_list */
0027     struct mptcp_addr_info  addr;	/* copy of the announced address */
0028     struct timer_list   add_timer;	/* fires mptcp_pm_add_timer() */
0029     struct mptcp_sock   *sock;	/* msk this entry belongs to */
0030     u8          retrans_times;	/* bumped by the timer, up to ADD_ADDR_RETRANS_MAX */
0031 };
0032 
/* Per network-namespace path-manager state, obtained through
 * pm_nl_get_pernet().
 */
0033 struct pm_nl_pernet {
0034     /* protects pernet updates */
0035     spinlock_t      lock;
    /* list of struct mptcp_pm_addr_entry; readers in this file walk it
     * with the RCU list iterators
     */
0036     struct list_head    local_addr_list;
    /* number of entries on local_addr_list, capped at MPTCP_PM_ADDR_MAX */
0037     unsigned int        addrs;
0038     unsigned int        stale_loss_cnt;
    /* limits below are read locklessly with READ_ONCE() by the
     * mptcp_pm_get_*() helpers
     */
0039     unsigned int        add_addr_signal_max;
0040     unsigned int        add_addr_accept_max;
0041     unsigned int        local_addr_max;
0042     unsigned int        subflows_max;
    /* starting offset for the next free-id search in id_bitmap */
0043     unsigned int        next_id;
    /* one bit per address id currently assigned to an endpoint */
0044     DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
0045 };
0046 
/* Upper bound on configured endpoints per netns; also the size of the
 * on-stack address vectors used when creating subflows.
 */
0047 #define MPTCP_PM_ADDR_MAX   8
/* Value of retrans_times at which the ADD_ADDR timer stops re-arming. */
0048 #define ADD_ADDR_RETRANS_MAX    3
0049 
0050 static struct pm_nl_pernet *pm_nl_get_pernet(const struct net *net)
0051 {
0052     return net_generic(net, pm_nl_pernet_id);
0053 }
0054 
/* Return the path-manager per-netns state of the namespace @msk lives in. */
static struct pm_nl_pernet *
pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk)
{
	const struct net *net = sock_net((struct sock *)msk);

	return pm_nl_get_pernet(net);
}
0060 
0061 bool mptcp_addresses_equal(const struct mptcp_addr_info *a,
0062                const struct mptcp_addr_info *b, bool use_port)
0063 {
0064     bool addr_equals = false;
0065 
0066     if (a->family == b->family) {
0067         if (a->family == AF_INET)
0068             addr_equals = a->addr.s_addr == b->addr.s_addr;
0069 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
0070         else
0071             addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
0072     } else if (a->family == AF_INET) {
0073         if (ipv6_addr_v4mapped(&b->addr6))
0074             addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
0075     } else if (b->family == AF_INET) {
0076         if (ipv6_addr_v4mapped(&a->addr6))
0077             addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
0078 #endif
0079     }
0080 
0081     if (!addr_equals)
0082         return false;
0083     if (!use_port)
0084         return true;
0085 
0086     return a->port == b->port;
0087 }
0088 
0089 static void local_address(const struct sock_common *skc,
0090               struct mptcp_addr_info *addr)
0091 {
0092     addr->family = skc->skc_family;
0093     addr->port = htons(skc->skc_num);
0094     if (addr->family == AF_INET)
0095         addr->addr.s_addr = skc->skc_rcv_saddr;
0096 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
0097     else if (addr->family == AF_INET6)
0098         addr->addr6 = skc->skc_v6_rcv_saddr;
0099 #endif
0100 }
0101 
0102 static void remote_address(const struct sock_common *skc,
0103                struct mptcp_addr_info *addr)
0104 {
0105     addr->family = skc->skc_family;
0106     addr->port = skc->skc_dport;
0107     if (addr->family == AF_INET)
0108         addr->addr.s_addr = skc->skc_daddr;
0109 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
0110     else if (addr->family == AF_INET6)
0111         addr->addr6 = skc->skc_v6_daddr;
0112 #endif
0113 }
0114 
0115 static bool lookup_subflow_by_saddr(const struct list_head *list,
0116                     const struct mptcp_addr_info *saddr)
0117 {
0118     struct mptcp_subflow_context *subflow;
0119     struct mptcp_addr_info cur;
0120     struct sock_common *skc;
0121 
0122     list_for_each_entry(subflow, list, node) {
0123         skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
0124 
0125         local_address(skc, &cur);
0126         if (mptcp_addresses_equal(&cur, saddr, saddr->port))
0127             return true;
0128     }
0129 
0130     return false;
0131 }
0132 
0133 static bool lookup_subflow_by_daddr(const struct list_head *list,
0134                     const struct mptcp_addr_info *daddr)
0135 {
0136     struct mptcp_subflow_context *subflow;
0137     struct mptcp_addr_info cur;
0138     struct sock_common *skc;
0139 
0140     list_for_each_entry(subflow, list, node) {
0141         skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
0142 
0143         remote_address(skc, &cur);
0144         if (mptcp_addresses_equal(&cur, daddr, daddr->port))
0145             return true;
0146     }
0147 
0148     return false;
0149 }
0150 
0151 static struct mptcp_pm_addr_entry *
0152 select_local_address(const struct pm_nl_pernet *pernet,
0153              const struct mptcp_sock *msk)
0154 {
0155     const struct sock *sk = (const struct sock *)msk;
0156     struct mptcp_pm_addr_entry *entry, *ret = NULL;
0157 
0158     msk_owned_by_me(msk);
0159 
0160     rcu_read_lock();
0161     list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
0162         if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
0163             continue;
0164 
0165         if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
0166             continue;
0167 
0168         if (entry->addr.family != sk->sk_family) {
0169 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
0170             if ((entry->addr.family == AF_INET &&
0171                  !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
0172                 (sk->sk_family == AF_INET &&
0173                  !ipv6_addr_v4mapped(&entry->addr.addr6)))
0174 #endif
0175                 continue;
0176         }
0177 
0178         ret = entry;
0179         break;
0180     }
0181     rcu_read_unlock();
0182     return ret;
0183 }
0184 
0185 static struct mptcp_pm_addr_entry *
0186 select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk)
0187 {
0188     struct mptcp_pm_addr_entry *entry, *ret = NULL;
0189 
0190     rcu_read_lock();
0191     /* do not keep any additional per socket state, just signal
0192      * the address list in order.
0193      * Note: removal from the local address list during the msk life-cycle
0194      * can lead to additional addresses not being announced.
0195      */
0196     list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
0197         if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
0198             continue;
0199 
0200         if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
0201             continue;
0202 
0203         ret = entry;
0204         break;
0205     }
0206     rcu_read_unlock();
0207     return ret;
0208 }
0209 
0210 unsigned int mptcp_pm_get_add_addr_signal_max(const struct mptcp_sock *msk)
0211 {
0212     const struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
0213 
0214     return READ_ONCE(pernet->add_addr_signal_max);
0215 }
0216 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
0217 
0218 unsigned int mptcp_pm_get_add_addr_accept_max(const struct mptcp_sock *msk)
0219 {
0220     struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
0221 
0222     return READ_ONCE(pernet->add_addr_accept_max);
0223 }
0224 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
0225 
0226 unsigned int mptcp_pm_get_subflows_max(const struct mptcp_sock *msk)
0227 {
0228     struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
0229 
0230     return READ_ONCE(pernet->subflows_max);
0231 }
0232 EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
0233 
0234 unsigned int mptcp_pm_get_local_addr_max(const struct mptcp_sock *msk)
0235 {
0236     struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
0237 
0238     return READ_ONCE(pernet->local_addr_max);
0239 }
0240 EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
0241 
0242 bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk)
0243 {
0244     struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk);
0245 
0246     if (msk->pm.subflows == mptcp_pm_get_subflows_max(msk) ||
0247         (find_next_and_bit(pernet->id_bitmap, msk->pm.id_avail_bitmap,
0248                    MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) {
0249         WRITE_ONCE(msk->pm.work_pending, false);
0250         return false;
0251     }
0252     return true;
0253 }
0254 
0255 struct mptcp_pm_add_entry *
0256 mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk,
0257                 const struct mptcp_addr_info *addr)
0258 {
0259     struct mptcp_pm_add_entry *entry;
0260 
0261     lockdep_assert_held(&msk->pm.lock);
0262 
0263     list_for_each_entry(entry, &msk->pm.anno_list, list) {
0264         if (mptcp_addresses_equal(&entry->addr, addr, true))
0265             return entry;
0266     }
0267 
0268     return NULL;
0269 }
0270 
0271 bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
0272 {
0273     struct mptcp_pm_add_entry *entry;
0274     struct mptcp_addr_info saddr;
0275     bool ret = false;
0276 
0277     local_address((struct sock_common *)sk, &saddr);
0278 
0279     spin_lock_bh(&msk->pm.lock);
0280     list_for_each_entry(entry, &msk->pm.anno_list, list) {
0281         if (mptcp_addresses_equal(&entry->addr, &saddr, true)) {
0282             ret = true;
0283             goto out;
0284         }
0285     }
0286 
0287 out:
0288     spin_unlock_bh(&msk->pm.lock);
0289     return ret;
0290 }
0291 
0292 static void mptcp_pm_add_timer(struct timer_list *timer)
0293 {
0294     struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
0295     struct mptcp_sock *msk = entry->sock;
0296     struct sock *sk = (struct sock *)msk;
0297 
0298     pr_debug("msk=%p", msk);
0299 
0300     if (!msk)
0301         return;
0302 
0303     if (inet_sk_state_load(sk) == TCP_CLOSE)
0304         return;
0305 
0306     if (!entry->addr.id)
0307         return;
0308 
0309     if (mptcp_pm_should_add_signal_addr(msk)) {
0310         sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
0311         goto out;
0312     }
0313 
0314     spin_lock_bh(&msk->pm.lock);
0315 
0316     if (!mptcp_pm_should_add_signal_addr(msk)) {
0317         pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
0318         mptcp_pm_announce_addr(msk, &entry->addr, false);
0319         mptcp_pm_add_addr_send_ack(msk);
0320         entry->retrans_times++;
0321     }
0322 
0323     if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
0324         sk_reset_timer(sk, timer,
0325                    jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));
0326 
0327     spin_unlock_bh(&msk->pm.lock);
0328 
0329     if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
0330         mptcp_pm_subflow_established(msk);
0331 
0332 out:
0333     __sock_put(sk);
0334 }
0335 
0336 struct mptcp_pm_add_entry *
0337 mptcp_pm_del_add_timer(struct mptcp_sock *msk,
0338                const struct mptcp_addr_info *addr, bool check_id)
0339 {
0340     struct mptcp_pm_add_entry *entry;
0341     struct sock *sk = (struct sock *)msk;
0342 
0343     spin_lock_bh(&msk->pm.lock);
0344     entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
0345     if (entry && (!check_id || entry->addr.id == addr->id))
0346         entry->retrans_times = ADD_ADDR_RETRANS_MAX;
0347     spin_unlock_bh(&msk->pm.lock);
0348 
0349     if (entry && (!check_id || entry->addr.id == addr->id))
0350         sk_stop_timer_sync(sk, &entry->add_timer);
0351 
0352     return entry;
0353 }
0354 
0355 bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
0356                   const struct mptcp_pm_addr_entry *entry)
0357 {
0358     struct mptcp_pm_add_entry *add_entry = NULL;
0359     struct sock *sk = (struct sock *)msk;
0360     struct net *net = sock_net(sk);
0361 
0362     lockdep_assert_held(&msk->pm.lock);
0363 
0364     add_entry = mptcp_lookup_anno_list_by_saddr(msk, &entry->addr);
0365 
0366     if (add_entry) {
0367         if (mptcp_pm_is_kernel(msk))
0368             return false;
0369 
0370         sk_reset_timer(sk, &add_entry->add_timer,
0371                    jiffies + mptcp_get_add_addr_timeout(net));
0372         return true;
0373     }
0374 
0375     add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
0376     if (!add_entry)
0377         return false;
0378 
0379     list_add(&add_entry->list, &msk->pm.anno_list);
0380 
0381     add_entry->addr = entry->addr;
0382     add_entry->sock = msk;
0383     add_entry->retrans_times = 0;
0384 
0385     timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
0386     sk_reset_timer(sk, &add_entry->add_timer,
0387                jiffies + mptcp_get_add_addr_timeout(net));
0388 
0389     return true;
0390 }
0391 
0392 void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
0393 {
0394     struct mptcp_pm_add_entry *entry, *tmp;
0395     struct sock *sk = (struct sock *)msk;
0396     LIST_HEAD(free_list);
0397 
0398     pr_debug("msk=%p", msk);
0399 
0400     spin_lock_bh(&msk->pm.lock);
0401     list_splice_init(&msk->pm.anno_list, &free_list);
0402     spin_unlock_bh(&msk->pm.lock);
0403 
0404     list_for_each_entry_safe(entry, tmp, &free_list, list) {
0405         sk_stop_timer_sync(sk, &entry->add_timer);
0406         kfree(entry);
0407     }
0408 }
0409 
0410 static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned int nr,
0411                   const struct mptcp_addr_info *addr)
0412 {
0413     int i;
0414 
0415     for (i = 0; i < nr; i++) {
0416         if (addrs[i].id == addr->id)
0417             return true;
0418     }
0419 
0420     return false;
0421 }
0422 
/* Build the list of remote addresses to connect to from one local
 * endpoint.
 *
 * @msk:      MPTCP socket; msk->pm.subflows is incremented once per
 *            address stored in @addrs
 * @fullmesh: false - store only the msk's own remote address;
 *            true  - store the remote address of each existing subflow,
 *            one per distinct remote id, while below subflows_max
 * @addrs:    output array
 *
 * When the peer set remote_deny_join_id0, id 0 remotes are skipped (and
 * the non-fullmesh case yields nothing).
 *
 * NOTE(review): the capacity of @addrs is not passed in and not checked
 * here; the visible caller provides MPTCP_PM_ADDR_MAX slots - confirm
 * the subflow limits keep the index in range.
 */
0423 /* Fill all the remote addresses into the array addrs[],
0424  * and return the array size.
0425  */
0426 static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh,
0427                           struct mptcp_addr_info *addrs)
0428 {
0429     bool deny_id0 = READ_ONCE(msk->pm.remote_deny_join_id0);
0430     struct sock *sk = (struct sock *)msk, *ssk;
0431     struct mptcp_subflow_context *subflow;
0432     struct mptcp_addr_info remote = { 0 };
0433     unsigned int subflows_max;
0434     int i = 0;
0435 
0436     subflows_max = mptcp_pm_get_subflows_max(msk);
0437     remote_address((struct sock_common *)sk, &remote);
0438 
0439     /* Non-fullmesh endpoint, fill in the single entry
0440      * corresponding to the primary MPC subflow remote address
0441      */
0442     if (!fullmesh) {
0443         if (deny_id0)
0444             return 0;
0445 
0446         msk->pm.subflows++;
0447         addrs[i++] = remote;
0448     } else {
0449         mptcp_for_each_subflow(msk, subflow) {
0450             ssk = mptcp_subflow_tcp_sock(subflow);
            /* addrs[i] doubles as scratch space: it is filled first
             * and only kept (i++) if it passes the checks below
             */
0451             remote_address((struct sock_common *)ssk, &addrs[i]);
0452             addrs[i].id = subflow->remote_id;
0453             if (deny_id0 && !addrs[i].id)
0454                 continue;
0455 
            /* keep one entry per remote id, within the subflow limit */
0456             if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
0457                 msk->pm.subflows < subflows_max) {
0458                 msk->pm.subflows++;
0459                 i++;
0460             }
0461         }
0462     }
0463 
0464     return i;
0465 }
0466 
0467 static void __mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow,
0468                 bool prio, bool backup)
0469 {
0470     struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
0471     bool slow;
0472 
0473     pr_debug("send ack for %s",
0474          prio ? "mp_prio" : (mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr"));
0475 
0476     slow = lock_sock_fast(ssk);
0477     if (prio) {
0478         if (subflow->backup != backup)
0479             msk->last_snd = NULL;
0480 
0481         subflow->send_mp_prio = 1;
0482         subflow->backup = backup;
0483         subflow->request_bkup = backup;
0484     }
0485 
0486     __mptcp_subflow_send_ack(ssk);
0487     unlock_sock_fast(ssk, slow);
0488 }
0489 
0490 static void mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow,
0491                   bool prio, bool backup)
0492 {
0493     spin_unlock_bh(&msk->pm.lock);
0494     __mptcp_pm_send_ack(msk, subflow, prio, backup);
0495     spin_lock_bh(&msk->pm.lock);
0496 }
0497 
0498 static struct mptcp_pm_addr_entry *
0499 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
0500 {
0501     struct mptcp_pm_addr_entry *entry;
0502 
0503     list_for_each_entry(entry, &pernet->local_addr_list, list) {
0504         if (entry->addr.id == id)
0505             return entry;
0506     }
0507     return NULL;
0508 }
0509 
0510 static struct mptcp_pm_addr_entry *
0511 __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info,
0512           bool lookup_by_id)
0513 {
0514     struct mptcp_pm_addr_entry *entry;
0515 
0516     list_for_each_entry(entry, &pernet->local_addr_list, list) {
0517         if ((!lookup_by_id &&
0518              mptcp_addresses_equal(&entry->addr, info, entry->addr.port)) ||
0519             (lookup_by_id && entry->addr.id == info->id))
0520             return entry;
0521     }
0522     return NULL;
0523 }
0524 
/* Core step of the in-kernel path manager for @msk.
 *
 * In order:
 *  1. once per msk, account the endpoint matching the local address of
 *     the first subflow (msk->first) and, if that endpoint is a backup
 *     one, send an MP_PRIO on it;
 *  2. if the signal limit allows it, announce one available signal
 *     endpoint;
 *  3. while the local-address and subflow limits allow it, pick an
 *     available subflow endpoint and connect it to the remote
 *     address(es);
 *  4. refresh msk->pm.work_pending.
 *
 * Called with msk->pm.lock held; the lock is dropped and re-taken around
 * the connect calls (and inside mptcp_pm_send_ack()).
 *
 * NOTE(review): 'local' points into the endpoint list and is still used
 * after the select_*() helpers have left their RCU read section and
 * after pm.lock is dropped - confirm what keeps the entry alive.
 */
0525 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
0526 {
0527     struct sock *sk = (struct sock *)msk;
0528     struct mptcp_pm_addr_entry *local;
0529     unsigned int add_addr_signal_max;
0530     unsigned int local_addr_max;
0531     struct pm_nl_pernet *pernet;
0532     unsigned int subflows_max;
0533 
0534     pernet = pm_nl_get_pernet(sock_net(sk));
0535 
    /* snapshot the netns limits once for the whole run */
0536     add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
0537     local_addr_max = mptcp_pm_get_local_addr_max(msk);
0538     subflows_max = mptcp_pm_get_subflows_max(msk);
0539 
0540     /* do lazy endpoint usage accounting for the MPC subflows */
0541     if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) {
0542         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(msk->first);
0543         struct mptcp_pm_addr_entry *entry;
0544         struct mptcp_addr_info mpc_addr;
0545         bool backup = false;
0546 
0547         local_address((struct sock_common *)msk->first, &mpc_addr);
0548         rcu_read_lock();
0549         entry = __lookup_addr(pernet, &mpc_addr, false);
0550         if (entry) {
            /* the endpoint is in use by the first subflow: make it
             * unavailable and remember its id for later removals
             */
0551             __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap);
0552             msk->mpc_endpoint_id = entry->addr.id;
0553             backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP);
0554         }
0555         rcu_read_unlock();
0556 
0557         if (backup)
0558             mptcp_pm_send_ack(msk, subflow, true, backup);
0559 
0560         msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED);
0561     }
0562 
0563     pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
0564          msk->pm.local_addr_used, local_addr_max,
0565          msk->pm.add_addr_signaled, add_addr_signal_max,
0566          msk->pm.subflows, subflows_max);
0567 
0568     /* check first for announce */
0569     if (msk->pm.add_addr_signaled < add_addr_signal_max) {
0570         local = select_signal_address(pernet, msk);
0571 
0572         /* due to racing events on both ends we can reach here while
0573          * previous add address is still running: if we invoke now
0574          * mptcp_pm_announce_addr(), that will fail and the
0575          * corresponding id will be marked as used.
0576          * Instead let the PM machinery reschedule us when the
0577          * current address announce will be completed.
0578          */
0579         if (msk->pm.addr_signal & BIT(MPTCP_ADD_ADDR_SIGNAL))
0580             return;
0581 
0582         if (local) {
            /* the id is consumed only once the announce entry exists */
0583             if (mptcp_pm_alloc_anno_list(msk, local)) {
0584                 __clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
0585                 msk->pm.add_addr_signaled++;
0586                 mptcp_pm_announce_addr(msk, &local->addr, false);
0587                 mptcp_pm_nl_addr_send_ack(msk);
0588             }
0589         }
0590     }
0591 
0592     /* check if should create a new subflow */
0593     while (msk->pm.local_addr_used < local_addr_max &&
0594            msk->pm.subflows < subflows_max) {
0595         struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
0596         bool fullmesh;
0597         int i, nr;
0598 
0599         local = select_local_address(pernet, msk);
0600         if (!local)
0601             break;
0602 
0603         fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH);
0604 
        /* local_addr_used is bumped even when no remote gets selected;
         * the endpoint id is consumed only if nr != 0
         */
0605         msk->pm.local_addr_used++;
0606         nr = fill_remote_addresses_vec(msk, fullmesh, addrs);
0607         if (nr)
0608             __clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
        /* connect with pm.lock released, then take it back */
0609         spin_unlock_bh(&msk->pm.lock);
0610         for (i = 0; i < nr; i++)
0611             __mptcp_subflow_connect(sk, &local->addr, &addrs[i]);
0612         spin_lock_bh(&msk->pm.lock);
0613     }
0614     mptcp_pm_nl_check_work_pending(msk);
0615 }
0616 
/* PM work handler for MPTCP_PM_ESTABLISHED: run the common
 * announce-or-connect step.
 */
static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
0621 
/* PM work handler for MPTCP_PM_SUBFLOW_ESTABLISHED: run the common
 * announce-or-connect step.
 */
static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
0626 
0627 /* Fill all the local addresses into the array addrs[],
0628  * and return the array size.
0629  */
0630 static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
0631                          struct mptcp_addr_info *addrs)
0632 {
0633     struct sock *sk = (struct sock *)msk;
0634     struct mptcp_pm_addr_entry *entry;
0635     struct mptcp_addr_info local;
0636     struct pm_nl_pernet *pernet;
0637     unsigned int subflows_max;
0638     int i = 0;
0639 
0640     pernet = pm_nl_get_pernet_from_msk(msk);
0641     subflows_max = mptcp_pm_get_subflows_max(msk);
0642 
0643     rcu_read_lock();
0644     list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
0645         if (!(entry->flags & MPTCP_PM_ADDR_FLAG_FULLMESH))
0646             continue;
0647 
0648         if (entry->addr.family != sk->sk_family) {
0649 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
0650             if ((entry->addr.family == AF_INET &&
0651                  !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
0652                 (sk->sk_family == AF_INET &&
0653                  !ipv6_addr_v4mapped(&entry->addr.addr6)))
0654 #endif
0655                 continue;
0656         }
0657 
0658         if (msk->pm.subflows < subflows_max) {
0659             msk->pm.subflows++;
0660             addrs[i++] = entry->addr;
0661         }
0662     }
0663     rcu_read_unlock();
0664 
0665     /* If the array is empty, fill in the single
0666      * 'IPADDRANY' local address
0667      */
0668     if (!i) {
0669         memset(&local, 0, sizeof(local));
0670         local.family = msk->pm.remote.family;
0671 
0672         msk->pm.subflows++;
0673         addrs[i++] = local;
0674     }
0675 
0676     return i;
0677 }
0678 
0679 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
0680 {
0681     struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
0682     struct sock *sk = (struct sock *)msk;
0683     unsigned int add_addr_accept_max;
0684     struct mptcp_addr_info remote;
0685     unsigned int subflows_max;
0686     int i, nr;
0687 
0688     add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
0689     subflows_max = mptcp_pm_get_subflows_max(msk);
0690 
0691     pr_debug("accepted %d:%d remote family %d",
0692          msk->pm.add_addr_accepted, add_addr_accept_max,
0693          msk->pm.remote.family);
0694 
0695     remote = msk->pm.remote;
0696     mptcp_pm_announce_addr(msk, &remote, true);
0697     mptcp_pm_nl_addr_send_ack(msk);
0698 
0699     if (lookup_subflow_by_daddr(&msk->conn_list, &remote))
0700         return;
0701 
0702     /* pick id 0 port, if none is provided the remote address */
0703     if (!remote.port)
0704         remote.port = sk->sk_dport;
0705 
0706     /* connect to the specified remote address, using whatever
0707      * local address the routing configuration will pick.
0708      */
0709     nr = fill_local_addresses_vec(msk, addrs);
0710 
0711     msk->pm.add_addr_accepted++;
0712     if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
0713         msk->pm.subflows >= subflows_max)
0714         WRITE_ONCE(msk->pm.accept_addr, false);
0715 
0716     spin_unlock_bh(&msk->pm.lock);
0717     for (i = 0; i < nr; i++)
0718         __mptcp_subflow_connect(sk, &addrs[i], &remote);
0719     spin_lock_bh(&msk->pm.lock);
0720 }
0721 
0722 void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
0723 {
0724     struct mptcp_subflow_context *subflow;
0725 
0726     msk_owned_by_me(msk);
0727     lockdep_assert_held(&msk->pm.lock);
0728 
0729     if (!mptcp_pm_should_add_signal(msk) &&
0730         !mptcp_pm_should_rm_signal(msk))
0731         return;
0732 
0733     subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
0734     if (subflow)
0735         mptcp_pm_send_ack(msk, subflow, false, false);
0736 }
0737 
0738 int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
0739                  struct mptcp_addr_info *addr,
0740                  struct mptcp_addr_info *rem,
0741                  u8 bkup)
0742 {
0743     struct mptcp_subflow_context *subflow;
0744 
0745     pr_debug("bkup=%d", bkup);
0746 
0747     mptcp_for_each_subflow(msk, subflow) {
0748         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
0749         struct mptcp_addr_info local, remote;
0750 
0751         local_address((struct sock_common *)ssk, &local);
0752         if (!mptcp_addresses_equal(&local, addr, addr->port))
0753             continue;
0754 
0755         if (rem && rem->family != AF_UNSPEC) {
0756             remote_address((struct sock_common *)ssk, &remote);
0757             if (!mptcp_addresses_equal(&remote, rem, rem->port))
0758                 continue;
0759         }
0760 
0761         __mptcp_pm_send_ack(msk, subflow, true, bkup);
0762         return 0;
0763     }
0764 
0765     return -EINVAL;
0766 }
0767 
0768 static bool mptcp_local_id_match(const struct mptcp_sock *msk, u8 local_id, u8 id)
0769 {
0770     return local_id == id || (!local_id && msk->mpc_endpoint_id == id);
0771 }
0772 
0773 static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
0774                        const struct mptcp_rm_list *rm_list,
0775                        enum linux_mptcp_mib_field rm_type)
0776 {
0777     struct mptcp_subflow_context *subflow, *tmp;
0778     struct sock *sk = (struct sock *)msk;
0779     u8 i;
0780 
0781     pr_debug("%s rm_list_nr %d",
0782          rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);
0783 
0784     msk_owned_by_me(msk);
0785 
0786     if (sk->sk_state == TCP_LISTEN)
0787         return;
0788 
0789     if (!rm_list->nr)
0790         return;
0791 
0792     if (list_empty(&msk->conn_list))
0793         return;
0794 
0795     for (i = 0; i < rm_list->nr; i++) {
0796         u8 rm_id = rm_list->ids[i];
0797         bool removed = false;
0798 
0799         list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
0800             struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
0801             int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
0802             u8 id = subflow->local_id;
0803 
0804             if (rm_type == MPTCP_MIB_RMADDR && subflow->remote_id != rm_id)
0805                 continue;
0806             if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id))
0807                 continue;
0808 
0809             pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u",
0810                  rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
0811                  i, rm_id, subflow->local_id, subflow->remote_id,
0812                  msk->mpc_endpoint_id);
0813             spin_unlock_bh(&msk->pm.lock);
0814             mptcp_subflow_shutdown(sk, ssk, how);
0815 
0816             /* the following takes care of updating the subflows counter */
0817             mptcp_close_ssk(sk, ssk, subflow);
0818             spin_lock_bh(&msk->pm.lock);
0819 
0820             removed = true;
0821             __MPTCP_INC_STATS(sock_net(sk), rm_type);
0822         }
0823         if (rm_type == MPTCP_MIB_RMSUBFLOW)
0824             __set_bit(rm_id ? rm_id : msk->mpc_endpoint_id, msk->pm.id_avail_bitmap);
0825         if (!removed)
0826             continue;
0827 
0828         if (!mptcp_pm_is_kernel(msk))
0829             continue;
0830 
0831         if (rm_type == MPTCP_MIB_RMADDR) {
0832             msk->pm.add_addr_accepted--;
0833             WRITE_ONCE(msk->pm.accept_addr, true);
0834         } else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
0835             msk->pm.local_addr_used--;
0836         }
0837     }
0838 }
0839 
0840 static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
0841 {
0842     mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
0843 }
0844 
0845 void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
0846                      const struct mptcp_rm_list *rm_list)
0847 {
0848     mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
0849 }
0850 
0851 void mptcp_pm_nl_work(struct mptcp_sock *msk)
0852 {
0853     struct mptcp_pm_data *pm = &msk->pm;
0854 
0855     msk_owned_by_me(msk);
0856 
0857     if (!(pm->status & MPTCP_PM_WORK_MASK))
0858         return;
0859 
0860     spin_lock_bh(&msk->pm.lock);
0861 
0862     pr_debug("msk=%p status=%x", msk, pm->status);
0863     if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
0864         pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
0865         mptcp_pm_nl_add_addr_received(msk);
0866     }
0867     if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
0868         pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
0869         mptcp_pm_nl_addr_send_ack(msk);
0870     }
0871     if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
0872         pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
0873         mptcp_pm_nl_rm_addr_received(msk);
0874     }
0875     if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
0876         pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
0877         mptcp_pm_nl_fully_established(msk);
0878     }
0879     if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
0880         pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
0881         mptcp_pm_nl_subflow_established(msk);
0882     }
0883 
0884     spin_unlock_bh(&msk->pm.lock);
0885 }
0886 
0887 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
0888 {
0889     return (entry->flags &
0890         (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
0891         MPTCP_PM_ADDR_FLAG_SIGNAL;
0892 }
0893 
0894 /* caller must ensure the RCU grace period is already elapsed */
0895 static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)
0896 {
0897     if (entry->lsk)
0898         sock_release(entry->lsk);
0899     kfree(entry);
0900 }
0901 
0902 static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
0903                          struct mptcp_pm_addr_entry *entry)
0904 {
0905     struct mptcp_pm_addr_entry *cur, *del_entry = NULL;
0906     unsigned int addr_max;
0907     int ret = -EINVAL;
0908 
0909     spin_lock_bh(&pernet->lock);
0910     /* to keep the code simple, don't do IDR-like allocation for address ID,
0911      * just bail when we exceed limits
0912      */
0913     if (pernet->next_id == MPTCP_PM_MAX_ADDR_ID)
0914         pernet->next_id = 1;
0915     if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
0916         goto out;
0917     if (test_bit(entry->addr.id, pernet->id_bitmap))
0918         goto out;
0919 
0920     /* do not insert duplicate address, differentiate on port only
0921      * singled addresses
0922      */
0923     if (!address_use_port(entry))
0924         entry->addr.port = 0;
0925     list_for_each_entry(cur, &pernet->local_addr_list, list) {
0926         if (mptcp_addresses_equal(&cur->addr, &entry->addr,
0927                       cur->addr.port || entry->addr.port)) {
0928             /* allow replacing the exiting endpoint only if such
0929              * endpoint is an implicit one and the user-space
0930              * did not provide an endpoint id
0931              */
0932             if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT))
0933                 goto out;
0934             if (entry->addr.id)
0935                 goto out;
0936 
0937             pernet->addrs--;
0938             entry->addr.id = cur->addr.id;
0939             list_del_rcu(&cur->list);
0940             del_entry = cur;
0941             break;
0942         }
0943     }
0944 
0945     if (!entry->addr.id) {
0946 find_next:
0947         entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
0948                             MPTCP_PM_MAX_ADDR_ID + 1,
0949                             pernet->next_id);
0950         if (!entry->addr.id && pernet->next_id != 1) {
0951             pernet->next_id = 1;
0952             goto find_next;
0953         }
0954     }
0955 
0956     if (!entry->addr.id)
0957         goto out;
0958 
0959     __set_bit(entry->addr.id, pernet->id_bitmap);
0960     if (entry->addr.id > pernet->next_id)
0961         pernet->next_id = entry->addr.id;
0962 
0963     if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
0964         addr_max = pernet->add_addr_signal_max;
0965         WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
0966     }
0967     if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
0968         addr_max = pernet->local_addr_max;
0969         WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
0970     }
0971 
0972     pernet->addrs++;
0973     if (!entry->addr.port)
0974         list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
0975     else
0976         list_add_rcu(&entry->list, &pernet->local_addr_list);
0977     ret = entry->addr.id;
0978 
0979 out:
0980     spin_unlock_bh(&pernet->lock);
0981 
0982     /* just replaced an existing entry, free it */
0983     if (del_entry) {
0984         synchronize_rcu();
0985         __mptcp_pm_release_addr_entry(del_entry);
0986     }
0987     return ret;
0988 }
0989 
0990 static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
0991                         struct mptcp_pm_addr_entry *entry)
0992 {
0993     int addrlen = sizeof(struct sockaddr_in);
0994     struct sockaddr_storage addr;
0995     struct mptcp_sock *msk;
0996     struct socket *ssock;
0997     int backlog = 1024;
0998     int err;
0999 
1000     err = sock_create_kern(sock_net(sk), entry->addr.family,
1001                    SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
1002     if (err)
1003         return err;
1004 
1005     msk = mptcp_sk(entry->lsk->sk);
1006     if (!msk) {
1007         err = -EINVAL;
1008         goto out;
1009     }
1010 
1011     ssock = __mptcp_nmpc_socket(msk);
1012     if (!ssock) {
1013         err = -EINVAL;
1014         goto out;
1015     }
1016 
1017     mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
1018 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1019     if (entry->addr.family == AF_INET6)
1020         addrlen = sizeof(struct sockaddr_in6);
1021 #endif
1022     err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen);
1023     if (err) {
1024         pr_warn("kernel_bind error, err=%d", err);
1025         goto out;
1026     }
1027 
1028     err = kernel_listen(ssock, backlog);
1029     if (err) {
1030         pr_warn("kernel_listen error, err=%d", err);
1031         goto out;
1032     }
1033 
1034     return 0;
1035 
1036 out:
1037     sock_release(entry->lsk);
1038     return err;
1039 }
1040 
1041 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
1042 {
1043     struct mptcp_pm_addr_entry *entry;
1044     struct mptcp_addr_info skc_local;
1045     struct mptcp_addr_info msk_local;
1046     struct pm_nl_pernet *pernet;
1047     int ret = -1;
1048 
1049     if (WARN_ON_ONCE(!msk))
1050         return -1;
1051 
1052     /* The 0 ID mapping is defined by the first subflow, copied into the msk
1053      * addr
1054      */
1055     local_address((struct sock_common *)msk, &msk_local);
1056     local_address((struct sock_common *)skc, &skc_local);
1057     if (mptcp_addresses_equal(&msk_local, &skc_local, false))
1058         return 0;
1059 
1060     if (mptcp_pm_is_userspace(msk))
1061         return mptcp_userspace_pm_get_local_id(msk, &skc_local);
1062 
1063     pernet = pm_nl_get_pernet_from_msk(msk);
1064 
1065     rcu_read_lock();
1066     list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
1067         if (mptcp_addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
1068             ret = entry->addr.id;
1069             break;
1070         }
1071     }
1072     rcu_read_unlock();
1073     if (ret >= 0)
1074         return ret;
1075 
1076     /* address not found, add to local list */
1077     entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
1078     if (!entry)
1079         return -ENOMEM;
1080 
1081     entry->addr = skc_local;
1082     entry->addr.id = 0;
1083     entry->addr.port = 0;
1084     entry->ifindex = 0;
1085     entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
1086     entry->lsk = NULL;
1087     ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1088     if (ret < 0)
1089         kfree(entry);
1090 
1091     return ret;
1092 }
1093 
/* Indexes of the multicast groups inside mptcp_pm_mcgrps[] */
#define MPTCP_PM_CMD_GRP_OFFSET       0
#define MPTCP_PM_EV_GRP_OFFSET        1

/* Generic netlink multicast groups, presumably registered with
 * mptcp_genl_family (the registration is not visible here - confirm).
 * Only the MPTCP_PM_EV_GRP_NAME group carries the GENL_UNS_ADMIN_PERM flag.
 */
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
    [MPTCP_PM_CMD_GRP_OFFSET]   = { .name = MPTCP_PM_CMD_GRP_NAME, },
    [MPTCP_PM_EV_GRP_OFFSET]        = { .name = MPTCP_PM_EV_GRP_NAME,
                        .flags = GENL_UNS_ADMIN_PERM,
                      },
};
1103 
/* Netlink policy for the attributes nested inside an address attribute
 * (MPTCP_PM_ATTR_ADDR and MPTCP_PM_ATTR_ADDR_REMOTE both nest this policy).
 * The IPv6 address must be exactly sizeof(struct in6_addr) bytes long; all
 * the other attributes are fixed-size integers.
 */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
    [MPTCP_PM_ADDR_ATTR_FAMILY] = { .type   = NLA_U16,  },
    [MPTCP_PM_ADDR_ATTR_ID]     = { .type   = NLA_U8,   },
    [MPTCP_PM_ADDR_ATTR_ADDR4]  = { .type   = NLA_U32,  },
    [MPTCP_PM_ADDR_ATTR_ADDR6]  =
        NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
    [MPTCP_PM_ADDR_ATTR_PORT]   = { .type   = NLA_U16   },
    [MPTCP_PM_ADDR_ATTR_FLAGS]  = { .type   = NLA_U32   },
    [MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type   = NLA_S32   },
};
1115 
/* Netlink policy for the top-level path manager attributes: the local and
 * remote addresses are nested attributes validated by mptcp_pm_addr_policy,
 * the limits and the token are u32 and the local ID is u8.
 */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
    [MPTCP_PM_ATTR_ADDR]        =
                    NLA_POLICY_NESTED(mptcp_pm_addr_policy),
    [MPTCP_PM_ATTR_RCV_ADD_ADDRS]   = { .type   = NLA_U32,  },
    [MPTCP_PM_ATTR_SUBFLOWS]    = { .type   = NLA_U32,  },
    [MPTCP_PM_ATTR_TOKEN]       = { .type   = NLA_U32,  },
    [MPTCP_PM_ATTR_LOC_ID]      = { .type   = NLA_U8,   },
    [MPTCP_PM_ATTR_ADDR_REMOTE] =
                    NLA_POLICY_NESTED(mptcp_pm_addr_policy),
};
1126 
1127 void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk)
1128 {
1129     struct mptcp_subflow_context *iter, *subflow = mptcp_subflow_ctx(ssk);
1130     struct sock *sk = (struct sock *)msk;
1131     unsigned int active_max_loss_cnt;
1132     struct net *net = sock_net(sk);
1133     unsigned int stale_loss_cnt;
1134     bool slow;
1135 
1136     stale_loss_cnt = mptcp_stale_loss_cnt(net);
1137     if (subflow->stale || !stale_loss_cnt || subflow->stale_count <= stale_loss_cnt)
1138         return;
1139 
1140     /* look for another available subflow not in loss state */
1141     active_max_loss_cnt = max_t(int, stale_loss_cnt - 1, 1);
1142     mptcp_for_each_subflow(msk, iter) {
1143         if (iter != subflow && mptcp_subflow_active(iter) &&
1144             iter->stale_count < active_max_loss_cnt) {
1145             /* we have some alternatives, try to mark this subflow as idle ...*/
1146             slow = lock_sock_fast(ssk);
1147             if (!tcp_rtx_and_write_queues_empty(ssk)) {
1148                 subflow->stale = 1;
1149                 __mptcp_retransmit_pending_data(sk);
1150                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE);
1151             }
1152             unlock_sock_fast(ssk, slow);
1153 
1154             /* always try to push the pending data regardless of re-injections:
1155              * we can possibly use backup subflows now, and subflow selection
1156              * is cheap under the msk socket lock
1157              */
1158             __mptcp_push_pending(sk, 0);
1159             return;
1160         }
1161     }
1162 }
1163 
1164 static int mptcp_pm_family_to_addr(int family)
1165 {
1166 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1167     if (family == AF_INET6)
1168         return MPTCP_PM_ADDR_ATTR_ADDR6;
1169 #endif
1170     return MPTCP_PM_ADDR_ATTR_ADDR4;
1171 }
1172 
1173 static int mptcp_pm_parse_pm_addr_attr(struct nlattr *tb[],
1174                        const struct nlattr *attr,
1175                        struct genl_info *info,
1176                        struct mptcp_addr_info *addr,
1177                        bool require_family)
1178 {
1179     int err, addr_addr;
1180 
1181     if (!attr) {
1182         GENL_SET_ERR_MSG(info, "missing address info");
1183         return -EINVAL;
1184     }
1185 
1186     /* no validation needed - was already done via nested policy */
1187     err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
1188                       mptcp_pm_addr_policy, info->extack);
1189     if (err)
1190         return err;
1191 
1192     if (tb[MPTCP_PM_ADDR_ATTR_ID])
1193         addr->id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
1194 
1195     if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
1196         if (!require_family)
1197             return err;
1198 
1199         NL_SET_ERR_MSG_ATTR(info->extack, attr,
1200                     "missing family");
1201         return -EINVAL;
1202     }
1203 
1204     addr->family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
1205     if (addr->family != AF_INET
1206 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1207         && addr->family != AF_INET6
1208 #endif
1209         ) {
1210         NL_SET_ERR_MSG_ATTR(info->extack, attr,
1211                     "unknown address family");
1212         return -EINVAL;
1213     }
1214     addr_addr = mptcp_pm_family_to_addr(addr->family);
1215     if (!tb[addr_addr]) {
1216         NL_SET_ERR_MSG_ATTR(info->extack, attr,
1217                     "missing address data");
1218         return -EINVAL;
1219     }
1220 
1221 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1222     if (addr->family == AF_INET6)
1223         addr->addr6 = nla_get_in6_addr(tb[addr_addr]);
1224     else
1225 #endif
1226         addr->addr.s_addr = nla_get_in_addr(tb[addr_addr]);
1227 
1228     if (tb[MPTCP_PM_ADDR_ATTR_PORT])
1229         addr->port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
1230 
1231     return err;
1232 }
1233 
1234 int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
1235             struct mptcp_addr_info *addr)
1236 {
1237     struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
1238 
1239     memset(addr, 0, sizeof(*addr));
1240 
1241     return mptcp_pm_parse_pm_addr_attr(tb, attr, info, addr, true);
1242 }
1243 
1244 int mptcp_pm_parse_entry(struct nlattr *attr, struct genl_info *info,
1245              bool require_family,
1246              struct mptcp_pm_addr_entry *entry)
1247 {
1248     struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
1249     int err;
1250 
1251     memset(entry, 0, sizeof(*entry));
1252 
1253     err = mptcp_pm_parse_pm_addr_attr(tb, attr, info, &entry->addr, require_family);
1254     if (err)
1255         return err;
1256 
1257     if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
1258         u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
1259 
1260         entry->ifindex = val;
1261     }
1262 
1263     if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
1264         entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
1265 
1266     if (tb[MPTCP_PM_ADDR_ATTR_PORT])
1267         entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
1268 
1269     return 0;
1270 }
1271 
/* Return the path manager per-netns data for the netns of the request */
static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
{
    struct net *net = genl_info_net(info);

    return pm_nl_get_pernet(net);
}
1276 
1277 static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
1278 {
1279     struct mptcp_sock *msk;
1280     long s_slot = 0, s_num = 0;
1281 
1282     while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1283         struct sock *sk = (struct sock *)msk;
1284 
1285         if (!READ_ONCE(msk->fully_established) ||
1286             mptcp_pm_is_userspace(msk))
1287             goto next;
1288 
1289         lock_sock(sk);
1290         spin_lock_bh(&msk->pm.lock);
1291         mptcp_pm_create_subflow_or_signal_addr(msk);
1292         spin_unlock_bh(&msk->pm.lock);
1293         release_sock(sk);
1294 
1295 next:
1296         sock_put(sk);
1297         cond_resched();
1298     }
1299 
1300     return 0;
1301 }
1302 
1303 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
1304 {
1305     struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1306     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1307     struct mptcp_pm_addr_entry addr, *entry;
1308     int ret;
1309 
1310     ret = mptcp_pm_parse_entry(attr, info, true, &addr);
1311     if (ret < 0)
1312         return ret;
1313 
1314     if (addr.addr.port && !(addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
1315         GENL_SET_ERR_MSG(info, "flags must have signal when using port");
1316         return -EINVAL;
1317     }
1318 
1319     if (addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL &&
1320         addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) {
1321         GENL_SET_ERR_MSG(info, "flags mustn't have both signal and fullmesh");
1322         return -EINVAL;
1323     }
1324 
1325     if (addr.flags & MPTCP_PM_ADDR_FLAG_IMPLICIT) {
1326         GENL_SET_ERR_MSG(info, "can't create IMPLICIT endpoint");
1327         return -EINVAL;
1328     }
1329 
1330     entry = kmalloc(sizeof(*entry), GFP_KERNEL);
1331     if (!entry) {
1332         GENL_SET_ERR_MSG(info, "can't allocate addr");
1333         return -ENOMEM;
1334     }
1335 
1336     *entry = addr;
1337     if (entry->addr.port) {
1338         ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
1339         if (ret) {
1340             GENL_SET_ERR_MSG(info, "create listen socket error");
1341             kfree(entry);
1342             return ret;
1343         }
1344     }
1345     ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1346     if (ret < 0) {
1347         GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
1348         if (entry->lsk)
1349             sock_release(entry->lsk);
1350         kfree(entry);
1351         return ret;
1352     }
1353 
1354     mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
1355 
1356     return 0;
1357 }
1358 
1359 int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,
1360                      u8 *flags, int *ifindex)
1361 {
1362     struct mptcp_pm_addr_entry *entry;
1363     struct sock *sk = (struct sock *)msk;
1364     struct net *net = sock_net(sk);
1365 
1366     *flags = 0;
1367     *ifindex = 0;
1368 
1369     if (id) {
1370         if (mptcp_pm_is_userspace(msk))
1371             return mptcp_userspace_pm_get_flags_and_ifindex_by_id(msk,
1372                                           id,
1373                                           flags,
1374                                           ifindex);
1375 
1376         rcu_read_lock();
1377         entry = __lookup_addr_by_id(pm_nl_get_pernet(net), id);
1378         if (entry) {
1379             *flags = entry->flags;
1380             *ifindex = entry->ifindex;
1381         }
1382         rcu_read_unlock();
1383     }
1384 
1385     return 0;
1386 }
1387 
1388 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1389                       const struct mptcp_addr_info *addr)
1390 {
1391     struct mptcp_pm_add_entry *entry;
1392 
1393     entry = mptcp_pm_del_add_timer(msk, addr, false);
1394     if (entry) {
1395         list_del(&entry->list);
1396         kfree(entry);
1397         return true;
1398     }
1399 
1400     return false;
1401 }
1402 
1403 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1404                       const struct mptcp_addr_info *addr,
1405                       bool force)
1406 {
1407     struct mptcp_rm_list list = { .nr = 0 };
1408     bool ret;
1409 
1410     list.ids[list.nr++] = addr->id;
1411 
1412     ret = remove_anno_list_by_saddr(msk, addr);
1413     if (ret || force) {
1414         spin_lock_bh(&msk->pm.lock);
1415         mptcp_pm_remove_addr(msk, &list);
1416         spin_unlock_bh(&msk->pm.lock);
1417     }
1418     return ret;
1419 }
1420 
1421 static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
1422                            const struct mptcp_pm_addr_entry *entry)
1423 {
1424     const struct mptcp_addr_info *addr = &entry->addr;
1425     struct mptcp_rm_list list = { .nr = 0 };
1426     long s_slot = 0, s_num = 0;
1427     struct mptcp_sock *msk;
1428 
1429     pr_debug("remove_id=%d", addr->id);
1430 
1431     list.ids[list.nr++] = addr->id;
1432 
1433     while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1434         struct sock *sk = (struct sock *)msk;
1435         bool remove_subflow;
1436 
1437         if (mptcp_pm_is_userspace(msk))
1438             goto next;
1439 
1440         if (list_empty(&msk->conn_list)) {
1441             mptcp_pm_remove_anno_addr(msk, addr, false);
1442             goto next;
1443         }
1444 
1445         lock_sock(sk);
1446         remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
1447         mptcp_pm_remove_anno_addr(msk, addr, remove_subflow &&
1448                       !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT));
1449         if (remove_subflow)
1450             mptcp_pm_remove_subflow(msk, &list);
1451         release_sock(sk);
1452 
1453 next:
1454         sock_put(sk);
1455         cond_resched();
1456     }
1457 
1458     return 0;
1459 }
1460 
1461 static int mptcp_nl_remove_id_zero_address(struct net *net,
1462                        struct mptcp_addr_info *addr)
1463 {
1464     struct mptcp_rm_list list = { .nr = 0 };
1465     long s_slot = 0, s_num = 0;
1466     struct mptcp_sock *msk;
1467 
1468     list.ids[list.nr++] = 0;
1469 
1470     while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1471         struct sock *sk = (struct sock *)msk;
1472         struct mptcp_addr_info msk_local;
1473 
1474         if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk))
1475             goto next;
1476 
1477         local_address((struct sock_common *)msk, &msk_local);
1478         if (!mptcp_addresses_equal(&msk_local, addr, addr->port))
1479             goto next;
1480 
1481         lock_sock(sk);
1482         spin_lock_bh(&msk->pm.lock);
1483         mptcp_pm_remove_addr(msk, &list);
1484         mptcp_pm_nl_rm_subflow_received(msk, &list);
1485         spin_unlock_bh(&msk->pm.lock);
1486         release_sock(sk);
1487 
1488 next:
1489         sock_put(sk);
1490         cond_resched();
1491     }
1492 
1493     return 0;
1494 }
1495 
1496 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
1497 {
1498     struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1499     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1500     struct mptcp_pm_addr_entry addr, *entry;
1501     unsigned int addr_max;
1502     int ret;
1503 
1504     ret = mptcp_pm_parse_entry(attr, info, false, &addr);
1505     if (ret < 0)
1506         return ret;
1507 
1508     /* the zero id address is special: the first address used by the msk
1509      * always gets such an id, so different subflows can have different zero
1510      * id addresses. Additionally zero id is not accounted for in id_bitmap.
1511      * Let's use an 'mptcp_rm_list' instead of the common remove code.
1512      */
1513     if (addr.addr.id == 0)
1514         return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);
1515 
1516     spin_lock_bh(&pernet->lock);
1517     entry = __lookup_addr_by_id(pernet, addr.addr.id);
1518     if (!entry) {
1519         GENL_SET_ERR_MSG(info, "address not found");
1520         spin_unlock_bh(&pernet->lock);
1521         return -EINVAL;
1522     }
1523     if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
1524         addr_max = pernet->add_addr_signal_max;
1525         WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
1526     }
1527     if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
1528         addr_max = pernet->local_addr_max;
1529         WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
1530     }
1531 
1532     pernet->addrs--;
1533     list_del_rcu(&entry->list);
1534     __clear_bit(entry->addr.id, pernet->id_bitmap);
1535     spin_unlock_bh(&pernet->lock);
1536 
1537     mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry);
1538     synchronize_rcu();
1539     __mptcp_pm_release_addr_entry(entry);
1540 
1541     return ret;
1542 }
1543 
1544 void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
1545                     struct list_head *rm_list)
1546 {
1547     struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
1548     struct mptcp_pm_addr_entry *entry;
1549 
1550     list_for_each_entry(entry, rm_list, list) {
1551         if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
1552             slist.nr < MPTCP_RM_IDS_MAX)
1553             slist.ids[slist.nr++] = entry->addr.id;
1554 
1555         if (remove_anno_list_by_saddr(msk, &entry->addr) &&
1556             alist.nr < MPTCP_RM_IDS_MAX)
1557             alist.ids[alist.nr++] = entry->addr.id;
1558     }
1559 
1560     if (alist.nr) {
1561         spin_lock_bh(&msk->pm.lock);
1562         mptcp_pm_remove_addr(msk, &alist);
1563         spin_unlock_bh(&msk->pm.lock);
1564     }
1565     if (slist.nr)
1566         mptcp_pm_remove_subflow(msk, &slist);
1567 }
1568 
1569 static void mptcp_nl_remove_addrs_list(struct net *net,
1570                        struct list_head *rm_list)
1571 {
1572     long s_slot = 0, s_num = 0;
1573     struct mptcp_sock *msk;
1574 
1575     if (list_empty(rm_list))
1576         return;
1577 
1578     while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1579         struct sock *sk = (struct sock *)msk;
1580 
1581         if (!mptcp_pm_is_userspace(msk)) {
1582             lock_sock(sk);
1583             mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
1584             release_sock(sk);
1585         }
1586 
1587         sock_put(sk);
1588         cond_resched();
1589     }
1590 }
1591 
1592 /* caller must ensure the RCU grace period is already elapsed */
1593 static void __flush_addrs(struct list_head *list)
1594 {
1595     while (!list_empty(list)) {
1596         struct mptcp_pm_addr_entry *cur;
1597 
1598         cur = list_entry(list->next,
1599                  struct mptcp_pm_addr_entry, list);
1600         list_del_rcu(&cur->list);
1601         __mptcp_pm_release_addr_entry(cur);
1602     }
1603 }
1604 
1605 static void __reset_counters(struct pm_nl_pernet *pernet)
1606 {
1607     WRITE_ONCE(pernet->add_addr_signal_max, 0);
1608     WRITE_ONCE(pernet->add_addr_accept_max, 0);
1609     WRITE_ONCE(pernet->local_addr_max, 0);
1610     pernet->addrs = 0;
1611 }
1612 
1613 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1614 {
1615     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1616     LIST_HEAD(free_list);
1617 
1618     spin_lock_bh(&pernet->lock);
1619     list_splice_init(&pernet->local_addr_list, &free_list);
1620     __reset_counters(pernet);
1621     pernet->next_id = 1;
1622     bitmap_zero(pernet->id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
1623     spin_unlock_bh(&pernet->lock);
1624     mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
1625     synchronize_rcu();
1626     __flush_addrs(&free_list);
1627     return 0;
1628 }
1629 
1630 static int mptcp_nl_fill_addr(struct sk_buff *skb,
1631                   struct mptcp_pm_addr_entry *entry)
1632 {
1633     struct mptcp_addr_info *addr = &entry->addr;
1634     struct nlattr *attr;
1635 
1636     attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
1637     if (!attr)
1638         return -EMSGSIZE;
1639 
1640     if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
1641         goto nla_put_failure;
1642     if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
1643         goto nla_put_failure;
1644     if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
1645         goto nla_put_failure;
1646     if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
1647         goto nla_put_failure;
1648     if (entry->ifindex &&
1649         nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
1650         goto nla_put_failure;
1651 
1652     if (addr->family == AF_INET &&
1653         nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
1654                 addr->addr.s_addr))
1655         goto nla_put_failure;
1656 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1657     else if (addr->family == AF_INET6 &&
1658          nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
1659         goto nla_put_failure;
1660 #endif
1661     nla_nest_end(skb, attr);
1662     return 0;
1663 
1664 nla_put_failure:
1665     nla_nest_cancel(skb, attr);
1666     return -EMSGSIZE;
1667 }
1668 
1669 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
1670 {
1671     struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1672     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1673     struct mptcp_pm_addr_entry addr, *entry;
1674     struct sk_buff *msg;
1675     void *reply;
1676     int ret;
1677 
1678     ret = mptcp_pm_parse_entry(attr, info, false, &addr);
1679     if (ret < 0)
1680         return ret;
1681 
1682     msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1683     if (!msg)
1684         return -ENOMEM;
1685 
1686     reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1687                   info->genlhdr->cmd);
1688     if (!reply) {
1689         GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1690         ret = -EMSGSIZE;
1691         goto fail;
1692     }
1693 
1694     spin_lock_bh(&pernet->lock);
1695     entry = __lookup_addr_by_id(pernet, addr.addr.id);
1696     if (!entry) {
1697         GENL_SET_ERR_MSG(info, "address not found");
1698         ret = -EINVAL;
1699         goto unlock_fail;
1700     }
1701 
1702     ret = mptcp_nl_fill_addr(msg, entry);
1703     if (ret)
1704         goto unlock_fail;
1705 
1706     genlmsg_end(msg, reply);
1707     ret = genlmsg_reply(msg, info);
1708     spin_unlock_bh(&pernet->lock);
1709     return ret;
1710 
1711 unlock_fail:
1712     spin_unlock_bh(&pernet->lock);
1713 
1714 fail:
1715     nlmsg_free(msg);
1716     return ret;
1717 }
1718 
1719 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
1720                    struct netlink_callback *cb)
1721 {
1722     struct net *net = sock_net(msg->sk);
1723     struct mptcp_pm_addr_entry *entry;
1724     struct pm_nl_pernet *pernet;
1725     int id = cb->args[0];
1726     void *hdr;
1727     int i;
1728 
1729     pernet = pm_nl_get_pernet(net);
1730 
1731     spin_lock_bh(&pernet->lock);
1732     for (i = id; i < MPTCP_PM_MAX_ADDR_ID + 1; i++) {
1733         if (test_bit(i, pernet->id_bitmap)) {
1734             entry = __lookup_addr_by_id(pernet, i);
1735             if (!entry)
1736                 break;
1737 
1738             if (entry->addr.id <= id)
1739                 continue;
1740 
1741             hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
1742                       cb->nlh->nlmsg_seq, &mptcp_genl_family,
1743                       NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
1744             if (!hdr)
1745                 break;
1746 
1747             if (mptcp_nl_fill_addr(msg, entry) < 0) {
1748                 genlmsg_cancel(msg, hdr);
1749                 break;
1750             }
1751 
1752             id = entry->addr.id;
1753             genlmsg_end(msg, hdr);
1754         }
1755     }
1756     spin_unlock_bh(&pernet->lock);
1757 
1758     cb->args[0] = id;
1759     return msg->len;
1760 }
1761 
1762 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1763 {
1764     struct nlattr *attr = info->attrs[id];
1765 
1766     if (!attr)
1767         return 0;
1768 
1769     *limit = nla_get_u32(attr);
1770     if (*limit > MPTCP_PM_ADDR_MAX) {
1771         GENL_SET_ERR_MSG(info, "limit greater than maximum");
1772         return -EINVAL;
1773     }
1774     return 0;
1775 }
1776 
1777 static int
1778 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1779 {
1780     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1781     unsigned int rcv_addrs, subflows;
1782     int ret;
1783 
1784     spin_lock_bh(&pernet->lock);
1785     rcv_addrs = pernet->add_addr_accept_max;
1786     ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1787     if (ret)
1788         goto unlock;
1789 
1790     subflows = pernet->subflows_max;
1791     ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1792     if (ret)
1793         goto unlock;
1794 
1795     WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1796     WRITE_ONCE(pernet->subflows_max, subflows);
1797 
1798 unlock:
1799     spin_unlock_bh(&pernet->lock);
1800     return ret;
1801 }
1802 
1803 static int
1804 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1805 {
1806     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1807     struct sk_buff *msg;
1808     void *reply;
1809 
1810     msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1811     if (!msg)
1812         return -ENOMEM;
1813 
1814     reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1815                   MPTCP_PM_CMD_GET_LIMITS);
1816     if (!reply)
1817         goto fail;
1818 
1819     if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1820             READ_ONCE(pernet->add_addr_accept_max)))
1821         goto fail;
1822 
1823     if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1824             READ_ONCE(pernet->subflows_max)))
1825         goto fail;
1826 
1827     genlmsg_end(msg, reply);
1828     return genlmsg_reply(msg, info);
1829 
1830 fail:
1831     GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1832     nlmsg_free(msg);
1833     return -EMSGSIZE;
1834 }
1835 
1836 static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk,
1837                  struct mptcp_addr_info *addr)
1838 {
1839     struct mptcp_rm_list list = { .nr = 0 };
1840 
1841     list.ids[list.nr++] = addr->id;
1842 
1843     spin_lock_bh(&msk->pm.lock);
1844     mptcp_pm_nl_rm_subflow_received(msk, &list);
1845     mptcp_pm_create_subflow_or_signal_addr(msk);
1846     spin_unlock_bh(&msk->pm.lock);
1847 }
1848 
1849 static int mptcp_nl_set_flags(struct net *net,
1850                   struct mptcp_addr_info *addr,
1851                   u8 bkup, u8 changed)
1852 {
1853     long s_slot = 0, s_num = 0;
1854     struct mptcp_sock *msk;
1855     int ret = -EINVAL;
1856 
1857     while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1858         struct sock *sk = (struct sock *)msk;
1859 
1860         if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk))
1861             goto next;
1862 
1863         lock_sock(sk);
1864         if (changed & MPTCP_PM_ADDR_FLAG_BACKUP)
1865             ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, NULL, bkup);
1866         if (changed & MPTCP_PM_ADDR_FLAG_FULLMESH)
1867             mptcp_pm_nl_fullmesh(msk, addr);
1868         release_sock(sk);
1869 
1870 next:
1871         sock_put(sk);
1872         cond_resched();
1873     }
1874 
1875     return ret;
1876 }
1877 
1878 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1879 {
1880     struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry;
1881     struct mptcp_pm_addr_entry remote = { .addr = { .family = AF_UNSPEC }, };
1882     struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
1883     struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
1884     struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1885     struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1886     u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP |
1887                MPTCP_PM_ADDR_FLAG_FULLMESH;
1888     struct net *net = sock_net(skb->sk);
1889     u8 bkup = 0, lookup_by_id = 0;
1890     int ret;
1891 
1892     ret = mptcp_pm_parse_entry(attr, info, false, &addr);
1893     if (ret < 0)
1894         return ret;
1895 
1896     if (attr_rem) {
1897         ret = mptcp_pm_parse_entry(attr_rem, info, false, &remote);
1898         if (ret < 0)
1899             return ret;
1900     }
1901 
1902     if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1903         bkup = 1;
1904     if (addr.addr.family == AF_UNSPEC) {
1905         lookup_by_id = 1;
1906         if (!addr.addr.id)
1907             return -EOPNOTSUPP;
1908     }
1909 
1910     if (token)
1911         return mptcp_userspace_pm_set_flags(sock_net(skb->sk),
1912                             token, &addr, &remote, bkup);
1913 
1914     spin_lock_bh(&pernet->lock);
1915     entry = __lookup_addr(pernet, &addr.addr, lookup_by_id);
1916     if (!entry) {
1917         spin_unlock_bh(&pernet->lock);
1918         return -EINVAL;
1919     }
1920     if ((addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) &&
1921         (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
1922         spin_unlock_bh(&pernet->lock);
1923         return -EINVAL;
1924     }
1925 
1926     changed = (addr.flags ^ entry->flags) & mask;
1927     entry->flags = (entry->flags & ~mask) | (addr.flags & mask);
1928     addr = *entry;
1929     spin_unlock_bh(&pernet->lock);
1930 
1931     mptcp_nl_set_flags(net, &addr.addr, bkup, changed);
1932     return 0;
1933 }
1934 
1935 static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
1936 {
1937     genlmsg_multicast_netns(&mptcp_genl_family, net,
1938                 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
1939 }
1940 
1941 bool mptcp_userspace_pm_active(const struct mptcp_sock *msk)
1942 {
1943     return genl_has_listeners(&mptcp_genl_family,
1944                   sock_net((const struct sock *)msk),
1945                   MPTCP_PM_EV_GRP_OFFSET);
1946 }
1947 
1948 static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
1949 {
1950     const struct inet_sock *issk = inet_sk(ssk);
1951     const struct mptcp_subflow_context *sf;
1952 
1953     if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
1954         return -EMSGSIZE;
1955 
1956     switch (ssk->sk_family) {
1957     case AF_INET:
1958         if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
1959             return -EMSGSIZE;
1960         if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
1961             return -EMSGSIZE;
1962         break;
1963 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1964     case AF_INET6: {
1965         const struct ipv6_pinfo *np = inet6_sk(ssk);
1966 
1967         if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
1968             return -EMSGSIZE;
1969         if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
1970             return -EMSGSIZE;
1971         break;
1972     }
1973 #endif
1974     default:
1975         WARN_ON_ONCE(1);
1976         return -EMSGSIZE;
1977     }
1978 
1979     if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
1980         return -EMSGSIZE;
1981     if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
1982         return -EMSGSIZE;
1983 
1984     sf = mptcp_subflow_ctx(ssk);
1985     if (WARN_ON_ONCE(!sf))
1986         return -EINVAL;
1987 
1988     if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
1989         return -EMSGSIZE;
1990 
1991     if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
1992         return -EMSGSIZE;
1993 
1994     return 0;
1995 }
1996 
1997 static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
1998                      const struct mptcp_sock *msk,
1999                      const struct sock *ssk)
2000 {
2001     const struct sock *sk = (const struct sock *)msk;
2002     const struct mptcp_subflow_context *sf;
2003     u8 sk_err;
2004 
2005     if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
2006         return -EMSGSIZE;
2007 
2008     if (mptcp_event_add_subflow(skb, ssk))
2009         return -EMSGSIZE;
2010 
2011     sf = mptcp_subflow_ctx(ssk);
2012     if (WARN_ON_ONCE(!sf))
2013         return -EINVAL;
2014 
2015     if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
2016         return -EMSGSIZE;
2017 
2018     if (ssk->sk_bound_dev_if &&
2019         nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
2020         return -EMSGSIZE;
2021 
2022     sk_err = ssk->sk_err;
2023     if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
2024         nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
2025         return -EMSGSIZE;
2026 
2027     return 0;
2028 }
2029 
/* Payload for the "subflow established"/"subflow priority" events: just the
 * token plus the subflow description.
 */
static int mptcp_event_sub_established(struct sk_buff *skb,
				       const struct mptcp_sock *msk,
				       const struct sock *ssk)
{
	int err = mptcp_event_put_token_and_ssk(skb, msk, ssk);

	return err;
}
2036 
2037 static int mptcp_event_sub_closed(struct sk_buff *skb,
2038                   const struct mptcp_sock *msk,
2039                   const struct sock *ssk)
2040 {
2041     const struct mptcp_subflow_context *sf;
2042 
2043     if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
2044         return -EMSGSIZE;
2045 
2046     sf = mptcp_subflow_ctx(ssk);
2047     if (!sf->reset_seen)
2048         return 0;
2049 
2050     if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
2051         return -EMSGSIZE;
2052 
2053     if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
2054         return -EMSGSIZE;
2055 
2056     return 0;
2057 }
2058 
2059 static int mptcp_event_created(struct sk_buff *skb,
2060                    const struct mptcp_sock *msk,
2061                    const struct sock *ssk)
2062 {
2063     int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
2064 
2065     if (err)
2066         return err;
2067 
2068     if (nla_put_u8(skb, MPTCP_ATTR_SERVER_SIDE, READ_ONCE(msk->pm.server_side)))
2069         return -EMSGSIZE;
2070 
2071     return mptcp_event_add_subflow(skb, ssk);
2072 }
2073 
2074 void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
2075 {
2076     struct net *net = sock_net((const struct sock *)msk);
2077     struct nlmsghdr *nlh;
2078     struct sk_buff *skb;
2079 
2080     if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
2081         return;
2082 
2083     skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
2084     if (!skb)
2085         return;
2086 
2087     nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
2088     if (!nlh)
2089         goto nla_put_failure;
2090 
2091     if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
2092         goto nla_put_failure;
2093 
2094     if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
2095         goto nla_put_failure;
2096 
2097     genlmsg_end(skb, nlh);
2098     mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
2099     return;
2100 
2101 nla_put_failure:
2102     kfree_skb(skb);
2103 }
2104 
2105 void mptcp_event_addr_announced(const struct sock *ssk,
2106                 const struct mptcp_addr_info *info)
2107 {
2108     struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
2109     struct mptcp_sock *msk = mptcp_sk(subflow->conn);
2110     struct net *net = sock_net(ssk);
2111     struct nlmsghdr *nlh;
2112     struct sk_buff *skb;
2113 
2114     if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
2115         return;
2116 
2117     skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
2118     if (!skb)
2119         return;
2120 
2121     nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
2122               MPTCP_EVENT_ANNOUNCED);
2123     if (!nlh)
2124         goto nla_put_failure;
2125 
2126     if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
2127         goto nla_put_failure;
2128 
2129     if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
2130         goto nla_put_failure;
2131 
2132     if (nla_put_be16(skb, MPTCP_ATTR_DPORT,
2133              info->port == 0 ?
2134              inet_sk(ssk)->inet_dport :
2135              info->port))
2136         goto nla_put_failure;
2137 
2138     switch (info->family) {
2139     case AF_INET:
2140         if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
2141             goto nla_put_failure;
2142         break;
2143 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2144     case AF_INET6:
2145         if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
2146             goto nla_put_failure;
2147         break;
2148 #endif
2149     default:
2150         WARN_ON_ONCE(1);
2151         goto nla_put_failure;
2152     }
2153 
2154     genlmsg_end(skb, nlh);
2155     mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
2156     return;
2157 
2158 nla_put_failure:
2159     kfree_skb(skb);
2160 }
2161 
2162 void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
2163          const struct sock *ssk, gfp_t gfp)
2164 {
2165     struct net *net = sock_net((const struct sock *)msk);
2166     struct nlmsghdr *nlh;
2167     struct sk_buff *skb;
2168 
2169     if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
2170         return;
2171 
2172     skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
2173     if (!skb)
2174         return;
2175 
2176     nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
2177     if (!nlh)
2178         goto nla_put_failure;
2179 
2180     switch (type) {
2181     case MPTCP_EVENT_UNSPEC:
2182         WARN_ON_ONCE(1);
2183         break;
2184     case MPTCP_EVENT_CREATED:
2185     case MPTCP_EVENT_ESTABLISHED:
2186         if (mptcp_event_created(skb, msk, ssk) < 0)
2187             goto nla_put_failure;
2188         break;
2189     case MPTCP_EVENT_CLOSED:
2190         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
2191             goto nla_put_failure;
2192         break;
2193     case MPTCP_EVENT_ANNOUNCED:
2194     case MPTCP_EVENT_REMOVED:
2195         /* call mptcp_event_addr_announced()/removed instead */
2196         WARN_ON_ONCE(1);
2197         break;
2198     case MPTCP_EVENT_SUB_ESTABLISHED:
2199     case MPTCP_EVENT_SUB_PRIORITY:
2200         if (mptcp_event_sub_established(skb, msk, ssk) < 0)
2201             goto nla_put_failure;
2202         break;
2203     case MPTCP_EVENT_SUB_CLOSED:
2204         if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
2205             goto nla_put_failure;
2206         break;
2207     }
2208 
2209     genlmsg_end(skb, nlh);
2210     mptcp_nl_mcast_send(net, skb, gfp);
2211     return;
2212 
2213 nla_put_failure:
2214     kfree_skb(skb);
2215 }
2216 
2217 static const struct genl_small_ops mptcp_pm_ops[] = {
2218     {
2219         .cmd    = MPTCP_PM_CMD_ADD_ADDR,
2220         .doit   = mptcp_nl_cmd_add_addr,
2221         .flags  = GENL_ADMIN_PERM,
2222     },
2223     {
2224         .cmd    = MPTCP_PM_CMD_DEL_ADDR,
2225         .doit   = mptcp_nl_cmd_del_addr,
2226         .flags  = GENL_ADMIN_PERM,
2227     },
2228     {
2229         .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
2230         .doit   = mptcp_nl_cmd_flush_addrs,
2231         .flags  = GENL_ADMIN_PERM,
2232     },
2233     {
2234         .cmd    = MPTCP_PM_CMD_GET_ADDR,
2235         .doit   = mptcp_nl_cmd_get_addr,
2236         .dumpit   = mptcp_nl_cmd_dump_addrs,
2237     },
2238     {
2239         .cmd    = MPTCP_PM_CMD_SET_LIMITS,
2240         .doit   = mptcp_nl_cmd_set_limits,
2241         .flags  = GENL_ADMIN_PERM,
2242     },
2243     {
2244         .cmd    = MPTCP_PM_CMD_GET_LIMITS,
2245         .doit   = mptcp_nl_cmd_get_limits,
2246     },
2247     {
2248         .cmd    = MPTCP_PM_CMD_SET_FLAGS,
2249         .doit   = mptcp_nl_cmd_set_flags,
2250         .flags  = GENL_ADMIN_PERM,
2251     },
2252     {
2253         .cmd    = MPTCP_PM_CMD_ANNOUNCE,
2254         .doit   = mptcp_nl_cmd_announce,
2255         .flags  = GENL_ADMIN_PERM,
2256     },
2257     {
2258         .cmd    = MPTCP_PM_CMD_REMOVE,
2259         .doit   = mptcp_nl_cmd_remove,
2260         .flags  = GENL_ADMIN_PERM,
2261     },
2262     {
2263         .cmd    = MPTCP_PM_CMD_SUBFLOW_CREATE,
2264         .doit   = mptcp_nl_cmd_sf_create,
2265         .flags  = GENL_ADMIN_PERM,
2266     },
2267     {
2268         .cmd    = MPTCP_PM_CMD_SUBFLOW_DESTROY,
2269         .doit   = mptcp_nl_cmd_sf_destroy,
2270         .flags  = GENL_ADMIN_PERM,
2271     },
2272 };
2273 
2274 static struct genl_family mptcp_genl_family __ro_after_init = {
2275     .name       = MPTCP_PM_NAME,
2276     .version    = MPTCP_PM_VER,
2277     .maxattr    = MPTCP_PM_ATTR_MAX,
2278     .policy     = mptcp_pm_policy,
2279     .netnsok    = true,
2280     .module     = THIS_MODULE,
2281     .small_ops  = mptcp_pm_ops,
2282     .n_small_ops    = ARRAY_SIZE(mptcp_pm_ops),
2283     .mcgrps     = mptcp_pm_mcgrps,
2284     .n_mcgrps   = ARRAY_SIZE(mptcp_pm_mcgrps),
2285 };
2286 
2287 static int __net_init pm_nl_init_net(struct net *net)
2288 {
2289     struct pm_nl_pernet *pernet = pm_nl_get_pernet(net);
2290 
2291     INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
2292 
2293     /* Cit. 2 subflows ought to be enough for anybody. */
2294     pernet->subflows_max = 2;
2295     pernet->next_id = 1;
2296     pernet->stale_loss_cnt = 4;
2297     spin_lock_init(&pernet->lock);
2298 
2299     /* No need to initialize other pernet fields, the struct is zeroed at
2300      * allocation time.
2301      */
2302 
2303     return 0;
2304 }
2305 
2306 static void __net_exit pm_nl_exit_net(struct list_head *net_list)
2307 {
2308     struct net *net;
2309 
2310     list_for_each_entry(net, net_list, exit_list) {
2311         struct pm_nl_pernet *pernet = pm_nl_get_pernet(net);
2312 
2313         /* net is removed from namespace list, can't race with
2314          * other modifiers, also netns core already waited for a
2315          * RCU grace period.
2316          */
2317         __flush_addrs(&pernet->local_addr_list);
2318     }
2319 }
2320 
2321 static struct pernet_operations mptcp_pm_pernet_ops = {
2322     .init = pm_nl_init_net,
2323     .exit_batch = pm_nl_exit_net,
2324     .id = &pm_nl_pernet_id,
2325     .size = sizeof(struct pm_nl_pernet),
2326 };
2327 
2328 void __init mptcp_pm_nl_init(void)
2329 {
2330     if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
2331         panic("Failed to register MPTCP PM pernet subsystem.\n");
2332 
2333     if (genl_register_family(&mptcp_genl_family))
2334         panic("Failed to register MPTCP PM netlink family\n");
2335 }