// SPDX-License-Identifier: GPL-2.0
#include <linux/rcupdate.h>
#include <linux/spinlock.h>
#include <linux/jiffies.h>
#include <linux/module.h>
#include <linux/cache.h>
#include <linux/slab.h>
#include <linux/init.h>
#include <linux/tcp.h>
#include <linux/hash.h>
#include <linux/tcp_metrics.h>
#include <linux/vmalloc.h>

#include <net/inet_connection_sock.h>
#include <net/net_namespace.h>
#include <net/request_sock.h>
#include <net/inetpeer.h>
#include <net/sock.h>
#include <net/ipv6.h>
#include <net/dst.h>
#include <net/tcp.h>
#include <net/genetlink.h>

static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *saddr,
						   const struct inetpeer_addr *daddr,
						   struct net *net, unsigned int hash);

struct tcp_fastopen_metrics {
	u16	mss;
	u16	syn_loss:10,		/* Recurring Fast Open SYN losses */
		try_exp:2;		/* Request w/ exp. option (once) */
	unsigned long	last_syn_loss;	/* Last Fast Open SYN loss */
	struct	tcp_fastopen_cookie	cookie;
};
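/* Note: syn_loss is a 10-bit counter, so at most 1023 consecutive Fast
 * Open SYN losses can be recorded; a try_exp value of 1 makes the next
 * cookie request use the experimental TFO option (see
 * tcp_fastopen_cache_get() below).
 */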

/* TCP_METRIC_MAX includes 2 extra fields for userspace compatibility;
 * the kernel stores RTT and RTTVAR in usec resolution only.
 */
#define TCP_METRIC_MAX_KERNEL (TCP_METRIC_MAX - 2)
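/* tcpm_vals[] below therefore has TCP_METRIC_MAX_KERNEL + 1 slots (RTT,
 * RTTVAR, SSTHRESH, CWND, REORDERING), with RTT and RTTVAR held in usec;
 * the *_US indices of the uapi enum are synthesized for userspace in
 * tcp_metrics_fill_info().
 */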

struct tcp_metrics_block {
	struct tcp_metrics_block __rcu	*tcpm_next;
	possible_net_t			tcpm_net;
	struct inetpeer_addr		tcpm_saddr;
	struct inetpeer_addr		tcpm_daddr;
	unsigned long			tcpm_stamp;
	u32				tcpm_lock;
	u32				tcpm_vals[TCP_METRIC_MAX_KERNEL + 1];
	struct tcp_fastopen_metrics	tcpm_fastopen;

	struct rcu_head			rcu_head;
};

static inline struct net *tm_net(struct tcp_metrics_block *tm)
{
	return read_pnet(&tm->tcpm_net);
}

static bool tcp_metric_locked(struct tcp_metrics_block *tm,
			      enum tcp_metric_index idx)
{
	return tm->tcpm_lock & (1 << idx);
}

static u32 tcp_metric_get(struct tcp_metrics_block *tm,
			  enum tcp_metric_index idx)
{
	return tm->tcpm_vals[idx];
}

static void tcp_metric_set(struct tcp_metrics_block *tm,
			   enum tcp_metric_index idx,
			   u32 val)
{
	tm->tcpm_vals[idx] = val;
}

static bool addr_same(const struct inetpeer_addr *a,
		      const struct inetpeer_addr *b)
{
	return inetpeer_addr_cmp(a, b) == 0;
}

struct tcpm_hash_bucket {
	struct tcp_metrics_block __rcu	*chain;
};

static struct tcpm_hash_bucket	*tcp_metrics_hash __read_mostly;
static unsigned int		tcp_metrics_hash_log __read_mostly;

static DEFINE_SPINLOCK(tcp_metrics_lock);
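/* Concurrency scheme: each bucket holds a singly linked, RCU-protected
 * chain of tcp_metrics_block. Lookups run under rcu_read_lock(); all
 * insertions, recycling and deletions serialize on tcp_metrics_lock.
 * Bucket selection mixes the peer address hash with net_hash_mix(net)
 * and folds the result into tcp_metrics_hash_log bits via hash_32().
 */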

/* (Re)seed a metrics block from the route: record which metrics the
 * route has locked (RTAX lock bits), copy the cached values (converting
 * RTT/RTTVAR from ms to usec) and, if requested, clear the Fast Open
 * state.
 */
static void tcpm_suck_dst(struct tcp_metrics_block *tm,
			  const struct dst_entry *dst,
			  bool fastopen_clear)
{
	u32 msval;
	u32 val;

	tm->tcpm_stamp = jiffies;

	val = 0;
	if (dst_metric_locked(dst, RTAX_RTT))
		val |= 1 << TCP_METRIC_RTT;
	if (dst_metric_locked(dst, RTAX_RTTVAR))
		val |= 1 << TCP_METRIC_RTTVAR;
	if (dst_metric_locked(dst, RTAX_SSTHRESH))
		val |= 1 << TCP_METRIC_SSTHRESH;
	if (dst_metric_locked(dst, RTAX_CWND))
		val |= 1 << TCP_METRIC_CWND;
	if (dst_metric_locked(dst, RTAX_REORDERING))
		val |= 1 << TCP_METRIC_REORDERING;
	tm->tcpm_lock = val;

	msval = dst_metric_raw(dst, RTAX_RTT);
	tm->tcpm_vals[TCP_METRIC_RTT] = msval * USEC_PER_MSEC;

	msval = dst_metric_raw(dst, RTAX_RTTVAR);
	tm->tcpm_vals[TCP_METRIC_RTTVAR] = msval * USEC_PER_MSEC;
	tm->tcpm_vals[TCP_METRIC_SSTHRESH] = dst_metric_raw(dst, RTAX_SSTHRESH);
	tm->tcpm_vals[TCP_METRIC_CWND] = dst_metric_raw(dst, RTAX_CWND);
	tm->tcpm_vals[TCP_METRIC_REORDERING] = dst_metric_raw(dst, RTAX_REORDERING);
	if (fastopen_clear) {
		tm->tcpm_fastopen.mss = 0;
		tm->tcpm_fastopen.syn_loss = 0;
		tm->tcpm_fastopen.try_exp = 0;
		tm->tcpm_fastopen.cookie.exp = false;
		tm->tcpm_fastopen.cookie.len = 0;
	}
}

#define TCP_METRICS_TIMEOUT		(60 * 60 * HZ)
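/* Entries untouched for an hour are considered stale; tcpm_check_stamp()
 * below re-seeds them from the current route metrics.
 */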

static void tcpm_check_stamp(struct tcp_metrics_block *tm, struct dst_entry *dst)
{
	if (tm && unlikely(time_after(jiffies, tm->tcpm_stamp + TCP_METRICS_TIMEOUT)))
		tcpm_suck_dst(tm, dst, false);
}

#define TCP_METRICS_RECLAIM_DEPTH	5
#define TCP_METRICS_RECLAIM_PTR		(struct tcp_metrics_block *) 0x1UL
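/* TCP_METRICS_RECLAIM_PTR is a sentinel, not a dereferenceable pointer:
 * __tcp_get_metrics() returns it instead of NULL when a lookup misses
 * on a chain deeper than TCP_METRICS_RECLAIM_DEPTH, telling tcpm_new()
 * to recycle the oldest entry in the bucket rather than allocate.
 */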

#define deref_locked(p)	\
	rcu_dereference_protected(p, lockdep_is_held(&tcp_metrics_lock))

static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst,
					  struct inetpeer_addr *saddr,
					  struct inetpeer_addr *daddr,
					  unsigned int hash)
{
	struct tcp_metrics_block *tm;
	struct net *net;
	bool reclaim = false;

	spin_lock_bh(&tcp_metrics_lock);
	net = dev_net(dst->dev);

	/* While we waited for the spinlock, the cache might have been
	 * populated with this entry, so we have to check again.
	 */
	tm = __tcp_get_metrics(saddr, daddr, net, hash);
	if (tm == TCP_METRICS_RECLAIM_PTR) {
		reclaim = true;
		tm = NULL;
	}
	if (tm) {
		tcpm_check_stamp(tm, dst);
		goto out_unlock;
	}

	if (unlikely(reclaim)) {
		struct tcp_metrics_block *oldest;

		oldest = deref_locked(tcp_metrics_hash[hash].chain);
		for (tm = deref_locked(oldest->tcpm_next); tm;
		     tm = deref_locked(tm->tcpm_next)) {
			if (time_before(tm->tcpm_stamp, oldest->tcpm_stamp))
				oldest = tm;
		}
		tm = oldest;
	} else {
		tm = kmalloc(sizeof(*tm), GFP_ATOMIC);
		if (!tm)
			goto out_unlock;
	}
	write_pnet(&tm->tcpm_net, net);
	tm->tcpm_saddr = *saddr;
	tm->tcpm_daddr = *daddr;

	tcpm_suck_dst(tm, dst, true);

	if (likely(!reclaim)) {
		tm->tcpm_next = tcp_metrics_hash[hash].chain;
		rcu_assign_pointer(tcp_metrics_hash[hash].chain, tm);
	}

out_unlock:
	spin_unlock_bh(&tcp_metrics_lock);
	return tm;
}

static struct tcp_metrics_block *tcp_get_encode(struct tcp_metrics_block *tm, int depth)
{
	if (tm)
		return tm;
	if (depth > TCP_METRICS_RECLAIM_DEPTH)
		return TCP_METRICS_RECLAIM_PTR;
	return NULL;
}

static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *saddr,
						   const struct inetpeer_addr *daddr,
						   struct net *net, unsigned int hash)
{
	struct tcp_metrics_block *tm;
	int depth = 0;

	for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
	     tm = rcu_dereference(tm->tcpm_next)) {
		if (addr_same(&tm->tcpm_saddr, saddr) &&
		    addr_same(&tm->tcpm_daddr, daddr) &&
		    net_eq(tm_net(tm), net))
			break;
		depth++;
	}
	return tcp_get_encode(tm, depth);
}

static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req,
						       struct dst_entry *dst)
{
	struct tcp_metrics_block *tm;
	struct inetpeer_addr saddr, daddr;
	unsigned int hash;
	struct net *net;

	saddr.family = req->rsk_ops->family;
	daddr.family = req->rsk_ops->family;
	switch (daddr.family) {
	case AF_INET:
		inetpeer_set_addr_v4(&saddr, inet_rsk(req)->ir_loc_addr);
		inetpeer_set_addr_v4(&daddr, inet_rsk(req)->ir_rmt_addr);
		hash = ipv4_addr_hash(inet_rsk(req)->ir_rmt_addr);
		break;
#if IS_ENABLED(CONFIG_IPV6)
	case AF_INET6:
		inetpeer_set_addr_v6(&saddr, &inet_rsk(req)->ir_v6_loc_addr);
		inetpeer_set_addr_v6(&daddr, &inet_rsk(req)->ir_v6_rmt_addr);
		hash = ipv6_addr_hash(&inet_rsk(req)->ir_v6_rmt_addr);
		break;
#endif
	default:
		return NULL;
	}

	net = dev_net(dst->dev);
	hash ^= net_hash_mix(net);
	hash = hash_32(hash, tcp_metrics_hash_log);

	for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
	     tm = rcu_dereference(tm->tcpm_next)) {
		if (addr_same(&tm->tcpm_saddr, &saddr) &&
		    addr_same(&tm->tcpm_daddr, &daddr) &&
		    net_eq(tm_net(tm), net))
			break;
	}
	tcpm_check_stamp(tm, dst);
	return tm;
}

static struct tcp_metrics_block *tcp_get_metrics(struct sock *sk,
						 struct dst_entry *dst,
						 bool create)
{
	struct tcp_metrics_block *tm;
	struct inetpeer_addr saddr, daddr;
	unsigned int hash;
	struct net *net;

	if (sk->sk_family == AF_INET) {
		inetpeer_set_addr_v4(&saddr, inet_sk(sk)->inet_saddr);
		inetpeer_set_addr_v4(&daddr, inet_sk(sk)->inet_daddr);
		hash = ipv4_addr_hash(inet_sk(sk)->inet_daddr);
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (sk->sk_family == AF_INET6) {
		if (ipv6_addr_v4mapped(&sk->sk_v6_daddr)) {
			inetpeer_set_addr_v4(&saddr, inet_sk(sk)->inet_saddr);
			inetpeer_set_addr_v4(&daddr, inet_sk(sk)->inet_daddr);
			hash = ipv4_addr_hash(inet_sk(sk)->inet_daddr);
		} else {
			inetpeer_set_addr_v6(&saddr, &sk->sk_v6_rcv_saddr);
			inetpeer_set_addr_v6(&daddr, &sk->sk_v6_daddr);
			hash = ipv6_addr_hash(&sk->sk_v6_daddr);
		}
	}
#endif
	else
		return NULL;

	net = dev_net(dst->dev);
	hash ^= net_hash_mix(net);
	hash = hash_32(hash, tcp_metrics_hash_log);

	tm = __tcp_get_metrics(&saddr, &daddr, net, hash);
	if (tm == TCP_METRICS_RECLAIM_PTR)
		tm = NULL;
	if (!tm && create)
		tm = tcpm_new(dst, &saddr, &daddr, hash);
	else
		tcpm_check_stamp(tm, dst);

	return tm;
}
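/* Note: v4-mapped IPv6 sockets are keyed above as plain IPv4, so they
 * share metrics with IPv4 connections to the same peer.
 */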

/* Save metrics learned by this TCP session. This function is called
 * only when TCP finishes successfully, i.e. when it enters TIME-WAIT
 * or goes from LAST-ACK to CLOSE.
 */
void tcp_update_metrics(struct sock *sk)
{
	const struct inet_connection_sock *icsk = inet_csk(sk);
	struct dst_entry *dst = __sk_dst_get(sk);
	struct tcp_sock *tp = tcp_sk(sk);
	struct net *net = sock_net(sk);
	struct tcp_metrics_block *tm;
	unsigned long rtt;
	u32 val;
	int m;

	sk_dst_confirm(sk);
	if (READ_ONCE(net->ipv4.sysctl_tcp_nometrics_save) || !dst)
		return;

	rcu_read_lock();
	if (icsk->icsk_backoff || !tp->srtt_us) {
		/* This session failed to estimate RTT. Why?
		 * Probably no packets returned in time. Reset our
		 * results.
		 */
		tm = tcp_get_metrics(sk, dst, false);
		if (tm && !tcp_metric_locked(tm, TCP_METRIC_RTT))
			tcp_metric_set(tm, TCP_METRIC_RTT, 0);
		goto out_unlock;
	} else
		tm = tcp_get_metrics(sk, dst, true);

	if (!tm)
		goto out_unlock;

	rtt = tcp_metric_get(tm, TCP_METRIC_RTT);
	m = rtt - tp->srtt_us;

	/* If the newly calculated RTT is larger than the stored one,
	 * store the new one. Otherwise, use an EWMA. Remember, RTT
	 * overestimation is always better than underestimation.
	 */
	if (!tcp_metric_locked(tm, TCP_METRIC_RTT)) {
		if (m <= 0)
			rtt = tp->srtt_us;
		else
			rtt -= (m >> 3);
		tcp_metric_set(tm, TCP_METRIC_RTT, rtt);
	}
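	/* A worked example (a sketch, not from the source): with a cached
	 * rtt of 120000 usec and srtt_us of 80000 usec, m = 40000, so the
	 * update above stores 120000 - (40000 >> 3) = 115000 usec, i.e.
	 * the EWMA rtt = (7 * rtt + srtt) / 8. The rttvar update below is
	 * the analogous var = (3 * var + m) / 4 when m < var (with m the
	 * halved, clamped deviation).
	 */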

	if (!tcp_metric_locked(tm, TCP_METRIC_RTTVAR)) {
		unsigned long var;

		if (m < 0)
			m = -m;

		/* Scale deviation to rttvar fixed point */
		m >>= 1;
		if (m < tp->mdev_us)
			m = tp->mdev_us;

		var = tcp_metric_get(tm, TCP_METRIC_RTTVAR);
		if (m >= var)
			var = m;
		else
			var -= (var - m) >> 2;

		tcp_metric_set(tm, TCP_METRIC_RTTVAR, var);
	}

	if (tcp_in_initial_slowstart(tp)) {
		/* Slow start still did not finish. */
		if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) &&
		    !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) {
			val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
			if (val && (tcp_snd_cwnd(tp) >> 1) > val)
				tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
					       tcp_snd_cwnd(tp) >> 1);
		}
		if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
			val = tcp_metric_get(tm, TCP_METRIC_CWND);
			if (tcp_snd_cwnd(tp) > val)
				tcp_metric_set(tm, TCP_METRIC_CWND,
					       tcp_snd_cwnd(tp));
		}
	} else if (!tcp_in_slow_start(tp) &&
		   icsk->icsk_ca_state == TCP_CA_Open) {
		/* Congestion avoidance phase, cwnd is reliable. */
		if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) &&
		    !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH))
			tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
				       max(tcp_snd_cwnd(tp) >> 1, tp->snd_ssthresh));
		if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
			val = tcp_metric_get(tm, TCP_METRIC_CWND);
			tcp_metric_set(tm, TCP_METRIC_CWND, (val + tcp_snd_cwnd(tp)) >> 1);
		}
	} else {
		/* Else slow start did not finish, so cwnd is meaningless
		 * and ssthresh may also be invalid.
		 */
		if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) {
			val = tcp_metric_get(tm, TCP_METRIC_CWND);
			tcp_metric_set(tm, TCP_METRIC_CWND,
				       (val + tp->snd_ssthresh) >> 1);
		}
		if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) &&
		    !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) {
			val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
			if (val && tp->snd_ssthresh > val)
				tcp_metric_set(tm, TCP_METRIC_SSTHRESH,
					       tp->snd_ssthresh);
		}
		if (!tcp_metric_locked(tm, TCP_METRIC_REORDERING)) {
			val = tcp_metric_get(tm, TCP_METRIC_REORDERING);
			if (val < tp->reordering &&
			    tp->reordering !=
			    READ_ONCE(net->ipv4.sysctl_tcp_reordering))
				tcp_metric_set(tm, TCP_METRIC_REORDERING,
					       tp->reordering);
		}
	}
	tm->tcpm_stamp = jiffies;
out_unlock:
	rcu_read_unlock();
}

/* Initialize metrics on socket. */

void tcp_init_metrics(struct sock *sk)
{
	struct dst_entry *dst = __sk_dst_get(sk);
	struct tcp_sock *tp = tcp_sk(sk);
	struct net *net = sock_net(sk);
	struct tcp_metrics_block *tm;
	u32 val, crtt = 0; /* cached RTT scaled by 8 */

	sk_dst_confirm(sk);
	if (!dst)
		goto reset;

	rcu_read_lock();
	tm = tcp_get_metrics(sk, dst, true);
	if (!tm) {
		rcu_read_unlock();
		goto reset;
	}

	if (tcp_metric_locked(tm, TCP_METRIC_CWND))
		tp->snd_cwnd_clamp = tcp_metric_get(tm, TCP_METRIC_CWND);

	val = READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) ?
	      0 : tcp_metric_get(tm, TCP_METRIC_SSTHRESH);
	if (val) {
		tp->snd_ssthresh = val;
		if (tp->snd_ssthresh > tp->snd_cwnd_clamp)
			tp->snd_ssthresh = tp->snd_cwnd_clamp;
	} else {
		/* ssthresh may have been reduced unnecessarily during the
		 * 3WHS. Restore it to its initial default.
		 */
		tp->snd_ssthresh = TCP_INFINITE_SSTHRESH;
	}
	val = tcp_metric_get(tm, TCP_METRIC_REORDERING);
	if (val && tp->reordering != val)
		tp->reordering = val;

	crtt = tcp_metric_get(tm, TCP_METRIC_RTT);
	rcu_read_unlock();
reset:
	/* The initial RTT measurement from the SYN/SYN-ACK is not ideal
	 * for seeding the RTO of later data packets because SYN packets
	 * are small. Use the per-dst cached values to seed the RTO but
	 * keep the RTT estimator variables intact (e.g., srtt, mdev,
	 * rttvar). The RTO is then updated as soon as the first data RTT
	 * sample arrives (tcp_rtt_estimator()), so the cached RTT only
	 * influences the first RTO, not later RTT estimation.
	 *
	 * If no RTT is available from the SYN (due to retransmits or SYN
	 * cookies) or from the cache, force a conservative 3 second
	 * timeout.
	 *
	 * A bit of theory: RTT is the time from sending a "normal" sized
	 * packet until it is ACKed. In normal circumstances sending small
	 * packets forces the peer to delay ACKs, so the calculation is
	 * still correct. The algorithm is adaptive and, provided we
	 * follow the specs, it NEVER underestimates RTT. BUT! If the peer
	 * plays tricks, sending "quick acks" long enough to drive the
	 * estimated RTT down and then abruptly switching to delayed ACKs,
	 * expect trouble.
	 */
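	/* A worked example (a sketch, assuming HZ=1000): srtt values are
	 * stored left-shifted by 3, so a cached 100 ms RTT reaches this
	 * point as crtt = 800000 usec. The division below by
	 * 8 * USEC_PER_SEC / HZ = 8000 converts it to 100 jiffies, and
	 * the seeded RTO becomes 100 + max(200, tcp_rto_min(sk)) jiffies.
	 */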
	if (crtt > tp->srtt_us) {
		/* Set RTO like tcp_rtt_estimator(), but from cached RTT. */
		crtt /= 8 * USEC_PER_SEC / HZ;
		inet_csk(sk)->icsk_rto = crtt + max(2 * crtt, tcp_rto_min(sk));
	} else if (tp->srtt_us == 0) {
		/* RFC6298: 5.7 We've failed to get a valid RTT sample from
		 * the 3WHS. This is most likely due to retransmission,
		 * including a spurious one. Reset the RTO back to 3 secs
		 * from the more aggressive 1 sec to avoid more spurious
		 * retransmission.
		 */
		tp->rttvar_us = jiffies_to_usecs(TCP_TIMEOUT_FALLBACK);
		tp->mdev_us = tp->mdev_max_us = tp->rttvar_us;

		inet_csk(sk)->icsk_rto = TCP_TIMEOUT_FALLBACK;
	}
}

bool tcp_peer_is_proven(struct request_sock *req, struct dst_entry *dst)
{
	struct tcp_metrics_block *tm;
	bool ret;

	if (!dst)
		return false;

	rcu_read_lock();
	tm = __tcp_get_metrics_req(req, dst);
	if (tm && tcp_metric_get(tm, TCP_METRIC_RTT))
		ret = true;
	else
		ret = false;
	rcu_read_unlock();

	return ret;
}

static DEFINE_SEQLOCK(fastopen_seqlock);
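/* fastopen_seqlock guards tcpm_fastopen: readers such as
 * tcp_fastopen_cache_get() loop on read_seqbegin()/read_seqretry()
 * until they see a consistent snapshot, while writers serialize with
 * write_seqlock_bh(). The RTT/cwnd metrics, by contrast, are simple
 * u32 loads and stores and rely on RCU only for block lifetime.
 */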

void tcp_fastopen_cache_get(struct sock *sk, u16 *mss,
			    struct tcp_fastopen_cookie *cookie)
{
	struct tcp_metrics_block *tm;

	rcu_read_lock();
	tm = tcp_get_metrics(sk, __sk_dst_get(sk), false);
	if (tm) {
		struct tcp_fastopen_metrics *tfom = &tm->tcpm_fastopen;
		unsigned int seq;

		do {
			seq = read_seqbegin(&fastopen_seqlock);
			if (tfom->mss)
				*mss = tfom->mss;
			*cookie = tfom->cookie;
			if (cookie->len <= 0 && tfom->try_exp == 1)
				cookie->exp = true;
		} while (read_seqretry(&fastopen_seqlock, seq));
	}
	rcu_read_unlock();
}

void tcp_fastopen_cache_set(struct sock *sk, u16 mss,
			    struct tcp_fastopen_cookie *cookie, bool syn_lost,
			    u16 try_exp)
{
	struct dst_entry *dst = __sk_dst_get(sk);
	struct tcp_metrics_block *tm;

	if (!dst)
		return;
	rcu_read_lock();
	tm = tcp_get_metrics(sk, dst, true);
	if (tm) {
		struct tcp_fastopen_metrics *tfom = &tm->tcpm_fastopen;

		write_seqlock_bh(&fastopen_seqlock);
		if (mss)
			tfom->mss = mss;
		if (cookie && cookie->len > 0)
			tfom->cookie = *cookie;
		else if (try_exp > tfom->try_exp &&
			 tfom->cookie.len <= 0 && !tfom->cookie.exp)
			tfom->try_exp = try_exp;
		if (syn_lost) {
			++tfom->syn_loss;
			tfom->last_syn_loss = jiffies;
		} else
			tfom->syn_loss = 0;
		write_sequnlock_bh(&fastopen_seqlock);
	}
	rcu_read_unlock();
}

static struct genl_family tcp_metrics_nl_family;

static const struct nla_policy tcp_metrics_nl_policy[TCP_METRICS_ATTR_MAX + 1] = {
	[TCP_METRICS_ATTR_ADDR_IPV4]	= { .type = NLA_U32, },
	[TCP_METRICS_ATTR_ADDR_IPV6]	= { .type = NLA_BINARY,
					    .len = sizeof(struct in6_addr), },
	/* The following attributes are not received for GET/DEL;
	 * we keep them here for reference.
	 */
#if 0
	[TCP_METRICS_ATTR_AGE]		= { .type = NLA_MSECS, },
	[TCP_METRICS_ATTR_TW_TSVAL]	= { .type = NLA_U32, },
	[TCP_METRICS_ATTR_TW_TS_STAMP]	= { .type = NLA_S32, },
	[TCP_METRICS_ATTR_VALS]		= { .type = NLA_NESTED, },
	[TCP_METRICS_ATTR_FOPEN_MSS]	= { .type = NLA_U16, },
	[TCP_METRICS_ATTR_FOPEN_SYN_DROPS]	= { .type = NLA_U16, },
	[TCP_METRICS_ATTR_FOPEN_SYN_DROP_TS]	= { .type = NLA_MSECS, },
	[TCP_METRICS_ATTR_FOPEN_COOKIE]	= { .type = NLA_BINARY,
					    .len = TCP_FASTOPEN_COOKIE_MAX, },
#endif
};

/* Add attributes; the caller cancels its header on failure. */
static int tcp_metrics_fill_info(struct sk_buff *msg,
				 struct tcp_metrics_block *tm)
{
	struct nlattr *nest;
	int i;

	switch (tm->tcpm_daddr.family) {
	case AF_INET:
		if (nla_put_in_addr(msg, TCP_METRICS_ATTR_ADDR_IPV4,
				    inetpeer_get_addr_v4(&tm->tcpm_daddr)) < 0)
			goto nla_put_failure;
		if (nla_put_in_addr(msg, TCP_METRICS_ATTR_SADDR_IPV4,
				    inetpeer_get_addr_v4(&tm->tcpm_saddr)) < 0)
			goto nla_put_failure;
		break;
	case AF_INET6:
		if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_ADDR_IPV6,
				     inetpeer_get_addr_v6(&tm->tcpm_daddr)) < 0)
			goto nla_put_failure;
		if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_SADDR_IPV6,
				     inetpeer_get_addr_v6(&tm->tcpm_saddr)) < 0)
			goto nla_put_failure;
		break;
	default:
		return -EAFNOSUPPORT;
	}

	if (nla_put_msecs(msg, TCP_METRICS_ATTR_AGE,
			  jiffies - tm->tcpm_stamp,
			  TCP_METRICS_ATTR_PAD) < 0)
		goto nla_put_failure;

	{
		int n = 0;

		nest = nla_nest_start_noflag(msg, TCP_METRICS_ATTR_VALS);
		if (!nest)
			goto nla_put_failure;
		for (i = 0; i < TCP_METRIC_MAX_KERNEL + 1; i++) {
			u32 val = tm->tcpm_vals[i];

			if (!val)
				continue;
			if (i == TCP_METRIC_RTT) {
				if (nla_put_u32(msg, TCP_METRIC_RTT_US + 1,
						val) < 0)
					goto nla_put_failure;
				n++;
				val = max(val / 1000, 1U);
			}
			if (i == TCP_METRIC_RTTVAR) {
				if (nla_put_u32(msg, TCP_METRIC_RTTVAR_US + 1,
						val) < 0)
					goto nla_put_failure;
				n++;
				val = max(val / 1000, 1U);
			}
			if (nla_put_u32(msg, i + 1, val) < 0)
				goto nla_put_failure;
			n++;
		}
		if (n)
			nla_nest_end(msg, nest);
		else
			nla_nest_cancel(msg, nest);
	}

	{
		struct tcp_fastopen_metrics tfom_copy[1], *tfom;
		unsigned int seq;

		do {
			seq = read_seqbegin(&fastopen_seqlock);
			tfom_copy[0] = tm->tcpm_fastopen;
		} while (read_seqretry(&fastopen_seqlock, seq));

		tfom = tfom_copy;
		if (tfom->mss &&
		    nla_put_u16(msg, TCP_METRICS_ATTR_FOPEN_MSS,
				tfom->mss) < 0)
			goto nla_put_failure;
		if (tfom->syn_loss &&
		    (nla_put_u16(msg, TCP_METRICS_ATTR_FOPEN_SYN_DROPS,
				tfom->syn_loss) < 0 ||
		     nla_put_msecs(msg, TCP_METRICS_ATTR_FOPEN_SYN_DROP_TS,
				jiffies - tfom->last_syn_loss,
				TCP_METRICS_ATTR_PAD) < 0))
			goto nla_put_failure;
		if (tfom->cookie.len > 0 &&
		    nla_put(msg, TCP_METRICS_ATTR_FOPEN_COOKIE,
			    tfom->cookie.len, tfom->cookie.val) < 0)
			goto nla_put_failure;
	}

	return 0;

nla_put_failure:
	return -EMSGSIZE;
}

static int tcp_metrics_dump_info(struct sk_buff *skb,
				 struct netlink_callback *cb,
				 struct tcp_metrics_block *tm)
{
	void *hdr;

	hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			  &tcp_metrics_nl_family, NLM_F_MULTI,
			  TCP_METRICS_CMD_GET);
	if (!hdr)
		return -EMSGSIZE;

	if (tcp_metrics_fill_info(skb, tm) < 0)
		goto nla_put_failure;

	genlmsg_end(skb, hdr);
	return 0;

nla_put_failure:
	genlmsg_cancel(skb, hdr);
	return -EMSGSIZE;
}
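/* The dump below walks every hash bucket; cb->args[0] (row) and
 * cb->args[1] (col) record how far a previous netlink callback got, so
 * a multi-part dump resumes where it left off.
 */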

static int tcp_metrics_nl_dump(struct sk_buff *skb,
			       struct netlink_callback *cb)
{
	struct net *net = sock_net(skb->sk);
	unsigned int max_rows = 1U << tcp_metrics_hash_log;
	unsigned int row, s_row = cb->args[0];
	int s_col = cb->args[1], col = s_col;

	for (row = s_row; row < max_rows; row++, s_col = 0) {
		struct tcp_metrics_block *tm;
		struct tcpm_hash_bucket *hb = tcp_metrics_hash + row;

		rcu_read_lock();
		for (col = 0, tm = rcu_dereference(hb->chain); tm;
		     tm = rcu_dereference(tm->tcpm_next), col++) {
			if (!net_eq(tm_net(tm), net))
				continue;
			if (col < s_col)
				continue;
			if (tcp_metrics_dump_info(skb, cb, tm) < 0) {
				rcu_read_unlock();
				goto done;
			}
		}
		rcu_read_unlock();
	}

done:
	cb->args[0] = row;
	cb->args[1] = col;
	return skb->len;
}

static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
			   unsigned int *hash, int optional, int v4, int v6)
{
	struct nlattr *a;

	a = info->attrs[v4];
	if (a) {
		inetpeer_set_addr_v4(addr, nla_get_in_addr(a));
		if (hash)
			*hash = ipv4_addr_hash(inetpeer_get_addr_v4(addr));
		return 0;
	}
	a = info->attrs[v6];
	if (a) {
		struct in6_addr in6;

		if (nla_len(a) != sizeof(struct in6_addr))
			return -EINVAL;
		in6 = nla_get_in6_addr(a);
		inetpeer_set_addr_v6(addr, &in6);
		if (hash)
			*hash = ipv6_addr_hash(inetpeer_get_addr_v6(addr));
		return 0;
	}
	return optional ? 1 : -EAFNOSUPPORT;
}

static int parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
			 unsigned int *hash, int optional)
{
	return __parse_nl_addr(info, addr, hash, optional,
			       TCP_METRICS_ATTR_ADDR_IPV4,
			       TCP_METRICS_ATTR_ADDR_IPV6);
}

static int parse_nl_saddr(struct genl_info *info, struct inetpeer_addr *addr)
{
	return __parse_nl_addr(info, addr, NULL, 0,
			       TCP_METRICS_ATTR_SADDR_IPV4,
			       TCP_METRICS_ATTR_SADDR_IPV6);
}

static int tcp_metrics_nl_cmd_get(struct sk_buff *skb, struct genl_info *info)
{
	struct tcp_metrics_block *tm;
	struct inetpeer_addr saddr, daddr;
	unsigned int hash;
	struct sk_buff *msg;
	struct net *net = genl_info_net(info);
	void *reply;
	int ret;
	bool src = true;

	ret = parse_nl_addr(info, &daddr, &hash, 0);
	if (ret < 0)
		return ret;

	ret = parse_nl_saddr(info, &saddr);
	if (ret < 0)
		src = false;

	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return -ENOMEM;

	reply = genlmsg_put_reply(msg, info, &tcp_metrics_nl_family, 0,
				  info->genlhdr->cmd);
	if (!reply)
		goto nla_put_failure;

	hash ^= net_hash_mix(net);
	hash = hash_32(hash, tcp_metrics_hash_log);
	ret = -ESRCH;
	rcu_read_lock();
	for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
	     tm = rcu_dereference(tm->tcpm_next)) {
		if (addr_same(&tm->tcpm_daddr, &daddr) &&
		    (!src || addr_same(&tm->tcpm_saddr, &saddr)) &&
		    net_eq(tm_net(tm), net)) {
			ret = tcp_metrics_fill_info(msg, tm);
			break;
		}
	}
	rcu_read_unlock();
	if (ret < 0)
		goto out_free;

	genlmsg_end(msg, reply);
	return genlmsg_reply(msg, info);

nla_put_failure:
	ret = -EMSGSIZE;

out_free:
	nlmsg_free(msg);
	return ret;
}
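/* tcp_metrics_flush_all() below removes every entry belonging to @net;
 * when called with a NULL net (from the netns exit batch path), it
 * instead removes entries whose namespace refcount has dropped to zero.
 */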

static void tcp_metrics_flush_all(struct net *net)
{
	unsigned int max_rows = 1U << tcp_metrics_hash_log;
	struct tcpm_hash_bucket *hb = tcp_metrics_hash;
	struct tcp_metrics_block *tm;
	unsigned int row;

	for (row = 0; row < max_rows; row++, hb++) {
		struct tcp_metrics_block __rcu **pp;
		bool match;

		spin_lock_bh(&tcp_metrics_lock);
		pp = &hb->chain;
		for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) {
			match = net ? net_eq(tm_net(tm), net) :
				!refcount_read(&tm_net(tm)->ns.count);
			if (match) {
				*pp = tm->tcpm_next;
				kfree_rcu(tm, rcu_head);
			} else {
				pp = &tm->tcpm_next;
			}
		}
		spin_unlock_bh(&tcp_metrics_lock);
	}
}

static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info)
{
	struct tcpm_hash_bucket *hb;
	struct tcp_metrics_block *tm;
	struct tcp_metrics_block __rcu **pp;
	struct inetpeer_addr saddr, daddr;
	unsigned int hash;
	struct net *net = genl_info_net(info);
	int ret;
	bool src = true, found = false;

	ret = parse_nl_addr(info, &daddr, &hash, 1);
	if (ret < 0)
		return ret;
	if (ret > 0) {
		tcp_metrics_flush_all(net);
		return 0;
	}
	ret = parse_nl_saddr(info, &saddr);
	if (ret < 0)
		src = false;

	hash ^= net_hash_mix(net);
	hash = hash_32(hash, tcp_metrics_hash_log);
	hb = tcp_metrics_hash + hash;
	pp = &hb->chain;
	spin_lock_bh(&tcp_metrics_lock);
	for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) {
		if (addr_same(&tm->tcpm_daddr, &daddr) &&
		    (!src || addr_same(&tm->tcpm_saddr, &saddr)) &&
		    net_eq(tm_net(tm), net)) {
			*pp = tm->tcpm_next;
			kfree_rcu(tm, rcu_head);
			found = true;
		} else {
			pp = &tm->tcpm_next;
		}
	}
	spin_unlock_bh(&tcp_metrics_lock);
	if (!found)
		return -ESRCH;
	return 0;
}

static const struct genl_small_ops tcp_metrics_nl_ops[] = {
	{
		.cmd = TCP_METRICS_CMD_GET,
		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
		.doit = tcp_metrics_nl_cmd_get,
		.dumpit = tcp_metrics_nl_dump,
	},
	{
		.cmd = TCP_METRICS_CMD_DEL,
		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
		.doit = tcp_metrics_nl_cmd_del,
		.flags = GENL_ADMIN_PERM,
	},
};

static struct genl_family tcp_metrics_nl_family __ro_after_init = {
	.hdrsize	= 0,
	.name		= TCP_METRICS_GENL_NAME,
	.version	= TCP_METRICS_GENL_VERSION,
	.maxattr	= TCP_METRICS_ATTR_MAX,
	.policy		= tcp_metrics_nl_policy,
	.netnsok	= true,
	.module		= THIS_MODULE,
	.small_ops	= tcp_metrics_nl_ops,
	.n_small_ops	= ARRAY_SIZE(tcp_metrics_nl_ops),
};
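/* This generic netlink family (TCP_METRICS_GENL_NAME, "tcp_metrics") is
 * what userspace tools such as iproute2's "ip tcp_metrics show" and
 * "ip tcp_metrics delete" talk to (an observation about common
 * userspace, not something this file depends on).
 */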

static unsigned int tcpmhash_entries;
static int __init set_tcpmhash_entries(char *str)
{
	ssize_t ret;

	if (!str)
		return 0;

	ret = kstrtouint(str, 0, &tcpmhash_entries);
	if (ret)
		return 0;

	return 1;
}
__setup("tcpmhash_entries=", set_tcpmhash_entries);
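/* For example, booting with "tcpmhash_entries=16384" on the kernel
 * command line overrides the RAM-based table sizing in
 * tcp_net_metrics_init() below.
 */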

static int __net_init tcp_net_metrics_init(struct net *net)
{
	size_t size;
	unsigned int slots;

	if (!net_eq(net, &init_net))
		return 0;

	slots = tcpmhash_entries;
	if (!slots) {
		if (totalram_pages() >= 128 * 1024)
			slots = 16 * 1024;
		else
			slots = 8 * 1024;
	}

	tcp_metrics_hash_log = order_base_2(slots);
	size = sizeof(struct tcpm_hash_bucket) << tcp_metrics_hash_log;

	tcp_metrics_hash = kvzalloc(size, GFP_KERNEL);
	if (!tcp_metrics_hash)
		return -ENOMEM;

	return 0;
}

static void __net_exit tcp_net_metrics_exit_batch(struct list_head *net_exit_list)
{
	tcp_metrics_flush_all(NULL);
}

static __net_initdata struct pernet_operations tcp_net_metrics_ops = {
	.init		=	tcp_net_metrics_init,
	.exit_batch	=	tcp_net_metrics_exit_batch,
};

void __init tcp_metrics_init(void)
{
	int ret;

	ret = register_pernet_subsys(&tcp_net_metrics_ops);
	if (ret < 0)
		panic("Could not allocate the tcp_metrics hash table\n");

	ret = genl_register_family(&tcp_metrics_nl_family);
	if (ret < 0)
		panic("Could not register tcp_metrics generic netlink\n");
}