Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-or-later
0002 /*
0003  * inet_diag.c  Module for monitoring INET transport protocols sockets.
0004  *
0005  * Authors: Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
0006  */
0007 
0008 #include <linux/kernel.h>
0009 #include <linux/module.h>
0010 #include <linux/types.h>
0011 #include <linux/fcntl.h>
0012 #include <linux/random.h>
0013 #include <linux/slab.h>
0014 #include <linux/cache.h>
0015 #include <linux/init.h>
0016 #include <linux/time.h>
0017 
0018 #include <net/icmp.h>
0019 #include <net/tcp.h>
0020 #include <net/ipv6.h>
0021 #include <net/inet_common.h>
0022 #include <net/inet_connection_sock.h>
0023 #include <net/inet_hashtables.h>
0024 #include <net/inet_timewait_sock.h>
0025 #include <net/inet6_hashtables.h>
0026 #include <net/bpf_sk_storage.h>
0027 #include <net/netlink.h>
0028 
0029 #include <linux/inet.h>
0030 #include <linux/stddef.h>
0031 
0032 #include <linux/inet_diag.h>
0033 #include <linux/sock_diag.h>
0034 
/*
 * Registered diag handlers, indexed by protocol number.
 * inet_diag_lock_handler() bounds-checks the index against IPPROTO_MAX
 * before using it.
 */
static const struct inet_diag_handler **inet_diag_table;

/*
 * Socket fields that the filter bytecode interpreter (inet_diag_bc_run())
 * matches against; filled in by inet_diag_bc_sk().
 */
struct inet_diag_entry {
	const __be32 *saddr;	/* local address words: 1 for IPv4, 4 for IPv6 */
	const __be32 *daddr;	/* remote address words, same layout as saddr */
	u16 sport;		/* local port, host byte order */
	u16 dport;		/* remote port, host byte order */
	u16 family;		/* sk->sk_family */
	u16 userlocks;		/* sk->sk_userlocks for full sockets, else 0 */
	u32 ifindex;		/* sk->sk_bound_dev_if */
	u32 mark;		/* socket, request-sock or timewait mark */
#ifdef CONFIG_SOCK_CGROUP_DATA
	u64 cgroup_id;		/* cgroup id for full sockets, else 0 */
#endif
};

/*
 * Held while a handler looked up in inet_diag_table is in use: taken by
 * inet_diag_lock_handler() and dropped by inet_diag_unlock_handler().
 * NOTE(review): presumably also taken by handler (un)registration, which is
 * outside this view.
 */
static DEFINE_MUTEX(inet_diag_table_mutex);
0052 
0053 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
0054 {
0055     if (proto < 0 || proto >= IPPROTO_MAX) {
0056         mutex_lock(&inet_diag_table_mutex);
0057         return ERR_PTR(-ENOENT);
0058     }
0059 
0060     if (!inet_diag_table[proto])
0061         sock_load_diag_module(AF_INET, proto);
0062 
0063     mutex_lock(&inet_diag_table_mutex);
0064     if (!inet_diag_table[proto])
0065         return ERR_PTR(-ENOENT);
0066 
0067     return inet_diag_table[proto];
0068 }
0069 
/*
 * Release the mutex taken by inet_diag_lock_handler().
 *
 * @handler is not used. It may be an ERR_PTR(): the lock is held even on a
 * failed lookup, so callers pass the lookup result back unconditionally.
 */
static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}
0074 
0075 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
0076 {
0077     r->idiag_family = sk->sk_family;
0078 
0079     r->id.idiag_sport = htons(sk->sk_num);
0080     r->id.idiag_dport = sk->sk_dport;
0081     r->id.idiag_if = sk->sk_bound_dev_if;
0082     sock_diag_save_cookie(sk, r->id.idiag_cookie);
0083 
0084 #if IS_ENABLED(CONFIG_IPV6)
0085     if (sk->sk_family == AF_INET6) {
0086         *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
0087         *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
0088     } else
0089 #endif
0090     {
0091     memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
0092     memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
0093 
0094     r->id.idiag_src[0] = sk->sk_rcv_saddr;
0095     r->id.idiag_dst[0] = sk->sk_daddr;
0096     }
0097 }
0098 EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
0099 
0100 static size_t inet_sk_attr_size(struct sock *sk,
0101                 const struct inet_diag_req_v2 *req,
0102                 bool net_admin)
0103 {
0104     const struct inet_diag_handler *handler;
0105     size_t aux = 0;
0106 
0107     handler = inet_diag_table[req->sdiag_protocol];
0108     if (handler && handler->idiag_get_aux_size)
0109         aux = handler->idiag_get_aux_size(sk, net_admin);
0110 
0111     return    nla_total_size(sizeof(struct tcp_info))
0112         + nla_total_size(sizeof(struct inet_diag_msg))
0113         + inet_diag_msg_attrs_size()
0114         + nla_total_size(sizeof(struct inet_diag_meminfo))
0115         + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
0116         + nla_total_size(TCP_CA_NAME_MAX)
0117         + nla_total_size(sizeof(struct tcpvegas_info))
0118         + aux
0119         + 64;
0120 }
0121 
0122 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
0123                  struct inet_diag_msg *r, int ext,
0124                  struct user_namespace *user_ns,
0125                  bool net_admin)
0126 {
0127     const struct inet_sock *inet = inet_sk(sk);
0128     struct inet_diag_sockopt inet_sockopt;
0129 
0130     if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
0131         goto errout;
0132 
0133     /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
0134      * hence this needs to be included regardless of socket family.
0135      */
0136     if (ext & (1 << (INET_DIAG_TOS - 1)))
0137         if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
0138             goto errout;
0139 
0140 #if IS_ENABLED(CONFIG_IPV6)
0141     if (r->idiag_family == AF_INET6) {
0142         if (ext & (1 << (INET_DIAG_TCLASS - 1)))
0143             if (nla_put_u8(skb, INET_DIAG_TCLASS,
0144                        inet6_sk(sk)->tclass) < 0)
0145                 goto errout;
0146 
0147         if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
0148             nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
0149             goto errout;
0150     }
0151 #endif
0152 
0153     if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
0154         goto errout;
0155 
0156     if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
0157         ext & (1 << (INET_DIAG_TCLASS - 1))) {
0158         u32 classid = 0;
0159 
0160 #ifdef CONFIG_SOCK_CGROUP_DATA
0161         classid = sock_cgroup_classid(&sk->sk_cgrp_data);
0162 #endif
0163         /* Fallback to socket priority if class id isn't set.
0164          * Classful qdiscs use it as direct reference to class.
0165          * For cgroup2 classid is always zero.
0166          */
0167         if (!classid)
0168             classid = sk->sk_priority;
0169 
0170         if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
0171             goto errout;
0172     }
0173 
0174 #ifdef CONFIG_SOCK_CGROUP_DATA
0175     if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
0176                   cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
0177                   INET_DIAG_PAD))
0178         goto errout;
0179 #endif
0180 
0181     r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
0182     r->idiag_inode = sock_i_ino(sk);
0183 
0184     memset(&inet_sockopt, 0, sizeof(inet_sockopt));
0185     inet_sockopt.recverr    = inet->recverr;
0186     inet_sockopt.is_icsk    = inet->is_icsk;
0187     inet_sockopt.freebind   = inet->freebind;
0188     inet_sockopt.hdrincl    = inet->hdrincl;
0189     inet_sockopt.mc_loop    = inet->mc_loop;
0190     inet_sockopt.transparent = inet->transparent;
0191     inet_sockopt.mc_all = inet->mc_all;
0192     inet_sockopt.nodefrag   = inet->nodefrag;
0193     inet_sockopt.bind_address_no_port = inet->bind_address_no_port;
0194     inet_sockopt.recverr_rfc4884 = inet->recverr_rfc4884;
0195     inet_sockopt.defer_connect = inet->defer_connect;
0196     if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt),
0197             &inet_sockopt))
0198         goto errout;
0199 
0200     return 0;
0201 errout:
0202     return 1;
0203 }
0204 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
0205 
0206 static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
0207                  struct nlattr **req_nlas)
0208 {
0209     struct nlattr *nla;
0210     int remaining;
0211 
0212     nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
0213         int type = nla_type(nla);
0214 
0215         if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
0216             return -EINVAL;
0217 
0218         if (type < __INET_DIAG_REQ_MAX)
0219             req_nlas[type] = nla;
0220     }
0221     return 0;
0222 }
0223 
0224 static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
0225                   const struct inet_diag_dump_data *data)
0226 {
0227     if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
0228         return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
0229     return req->sdiag_protocol;
0230 }
0231 
/*
 * Upper bound that inet_sk_diag_fill() applies when raising
 * cb->min_dump_alloc: KMALLOC_MAX_SIZE minus the aligned size of struct
 * skb_shared_info.
 * NOTE(review): presumably so that the skb data area plus its shared info
 * still fits in one kmalloc() allocation.
 */
#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
0233 
0234 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
0235               struct sk_buff *skb, struct netlink_callback *cb,
0236               const struct inet_diag_req_v2 *req,
0237               u16 nlmsg_flags, bool net_admin)
0238 {
0239     const struct tcp_congestion_ops *ca_ops;
0240     const struct inet_diag_handler *handler;
0241     struct inet_diag_dump_data *cb_data;
0242     int ext = req->idiag_ext;
0243     struct inet_diag_msg *r;
0244     struct nlmsghdr  *nlh;
0245     struct nlattr *attr;
0246     void *info = NULL;
0247 
0248     cb_data = cb->data;
0249     handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
0250     BUG_ON(!handler);
0251 
0252     nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
0253             cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
0254     if (!nlh)
0255         return -EMSGSIZE;
0256 
0257     r = nlmsg_data(nlh);
0258     BUG_ON(!sk_fullsock(sk));
0259 
0260     inet_diag_msg_common_fill(r, sk);
0261     r->idiag_state = sk->sk_state;
0262     r->idiag_timer = 0;
0263     r->idiag_retrans = 0;
0264     r->idiag_expires = 0;
0265 
0266     if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
0267                      sk_user_ns(NETLINK_CB(cb->skb).sk),
0268                      net_admin))
0269         goto errout;
0270 
0271     if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
0272         struct inet_diag_meminfo minfo = {
0273             .idiag_rmem = sk_rmem_alloc_get(sk),
0274             .idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
0275             .idiag_fmem = sk_forward_alloc_get(sk),
0276             .idiag_tmem = sk_wmem_alloc_get(sk),
0277         };
0278 
0279         if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
0280             goto errout;
0281     }
0282 
0283     if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
0284         if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
0285             goto errout;
0286 
0287     /*
0288      * RAW sockets might have user-defined protocols assigned,
0289      * so report the one supplied on socket creation.
0290      */
0291     if (sk->sk_type == SOCK_RAW) {
0292         if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
0293             goto errout;
0294     }
0295 
0296     if (!icsk) {
0297         handler->idiag_get_info(sk, r, NULL);
0298         goto out;
0299     }
0300 
0301     if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
0302         icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
0303         icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
0304         r->idiag_timer = 1;
0305         r->idiag_retrans = icsk->icsk_retransmits;
0306         r->idiag_expires =
0307             jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
0308     } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
0309         r->idiag_timer = 4;
0310         r->idiag_retrans = icsk->icsk_probes_out;
0311         r->idiag_expires =
0312             jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
0313     } else if (timer_pending(&sk->sk_timer)) {
0314         r->idiag_timer = 2;
0315         r->idiag_retrans = icsk->icsk_probes_out;
0316         r->idiag_expires =
0317             jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
0318     }
0319 
0320     if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
0321         attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
0322                      handler->idiag_info_size,
0323                      INET_DIAG_PAD);
0324         if (!attr)
0325             goto errout;
0326 
0327         info = nla_data(attr);
0328     }
0329 
0330     if (ext & (1 << (INET_DIAG_CONG - 1))) {
0331         int err = 0;
0332 
0333         rcu_read_lock();
0334         ca_ops = READ_ONCE(icsk->icsk_ca_ops);
0335         if (ca_ops)
0336             err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
0337         rcu_read_unlock();
0338         if (err < 0)
0339             goto errout;
0340     }
0341 
0342     handler->idiag_get_info(sk, r, info);
0343 
0344     if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
0345         if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
0346             goto errout;
0347 
0348     if (sk->sk_state < TCP_TIME_WAIT) {
0349         union tcp_cc_info info;
0350         size_t sz = 0;
0351         int attr;
0352 
0353         rcu_read_lock();
0354         ca_ops = READ_ONCE(icsk->icsk_ca_ops);
0355         if (ca_ops && ca_ops->get_info)
0356             sz = ca_ops->get_info(sk, ext, &attr, &info);
0357         rcu_read_unlock();
0358         if (sz && nla_put(skb, attr, sz, &info) < 0)
0359             goto errout;
0360     }
0361 
0362     /* Keep it at the end for potential retry with a larger skb,
0363      * or else do best-effort fitting, which is only done for the
0364      * first_nlmsg.
0365      */
0366     if (cb_data->bpf_stg_diag) {
0367         bool first_nlmsg = ((unsigned char *)nlh == skb->data);
0368         unsigned int prev_min_dump_alloc;
0369         unsigned int total_nla_size = 0;
0370         unsigned int msg_len;
0371         int err;
0372 
0373         msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
0374         err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
0375                           INET_DIAG_SK_BPF_STORAGES,
0376                           &total_nla_size);
0377 
0378         if (!err)
0379             goto out;
0380 
0381         total_nla_size += msg_len;
0382         prev_min_dump_alloc = cb->min_dump_alloc;
0383         if (total_nla_size > prev_min_dump_alloc)
0384             cb->min_dump_alloc = min_t(u32, total_nla_size,
0385                            MAX_DUMP_ALLOC_SIZE);
0386 
0387         if (!first_nlmsg)
0388             goto errout;
0389 
0390         if (cb->min_dump_alloc > prev_min_dump_alloc)
0391             /* Retry with pskb_expand_head() with
0392              * __GFP_DIRECT_RECLAIM
0393              */
0394             goto errout;
0395 
0396         WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
0397 
0398         /* Send what we have for this sk
0399          * and move on to the next sk in the following
0400          * dump()
0401          */
0402     }
0403 
0404 out:
0405     nlmsg_end(skb, nlh);
0406     return 0;
0407 
0408 errout:
0409     nlmsg_cancel(skb, nlh);
0410     return -EMSGSIZE;
0411 }
0412 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
0413 
0414 static int inet_twsk_diag_fill(struct sock *sk,
0415                    struct sk_buff *skb,
0416                    struct netlink_callback *cb,
0417                    u16 nlmsg_flags, bool net_admin)
0418 {
0419     struct inet_timewait_sock *tw = inet_twsk(sk);
0420     struct inet_diag_msg *r;
0421     struct nlmsghdr *nlh;
0422     long tmo;
0423 
0424     nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
0425             cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
0426             sizeof(*r), nlmsg_flags);
0427     if (!nlh)
0428         return -EMSGSIZE;
0429 
0430     r = nlmsg_data(nlh);
0431     BUG_ON(tw->tw_state != TCP_TIME_WAIT);
0432 
0433     inet_diag_msg_common_fill(r, sk);
0434     r->idiag_retrans      = 0;
0435 
0436     r->idiag_state        = tw->tw_substate;
0437     r->idiag_timer        = 3;
0438     tmo = tw->tw_timer.expires - jiffies;
0439     r->idiag_expires      = jiffies_delta_to_msecs(tmo);
0440     r->idiag_rqueue       = 0;
0441     r->idiag_wqueue       = 0;
0442     r->idiag_uid          = 0;
0443     r->idiag_inode        = 0;
0444 
0445     if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
0446                      tw->tw_mark)) {
0447         nlmsg_cancel(skb, nlh);
0448         return -EMSGSIZE;
0449     }
0450 
0451     nlmsg_end(skb, nlh);
0452     return 0;
0453 }
0454 
0455 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
0456                   struct netlink_callback *cb,
0457                   u16 nlmsg_flags, bool net_admin)
0458 {
0459     struct request_sock *reqsk = inet_reqsk(sk);
0460     struct inet_diag_msg *r;
0461     struct nlmsghdr *nlh;
0462     long tmo;
0463 
0464     nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
0465             cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
0466     if (!nlh)
0467         return -EMSGSIZE;
0468 
0469     r = nlmsg_data(nlh);
0470     inet_diag_msg_common_fill(r, sk);
0471     r->idiag_state = TCP_SYN_RECV;
0472     r->idiag_timer = 1;
0473     r->idiag_retrans = reqsk->num_retrans;
0474 
0475     BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
0476              offsetof(struct sock, sk_cookie));
0477 
0478     tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
0479     r->idiag_expires = jiffies_delta_to_msecs(tmo);
0480     r->idiag_rqueue = 0;
0481     r->idiag_wqueue = 0;
0482     r->idiag_uid    = 0;
0483     r->idiag_inode  = 0;
0484 
0485     if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
0486                      inet_rsk(reqsk)->ir_mark)) {
0487         nlmsg_cancel(skb, nlh);
0488         return -EMSGSIZE;
0489     }
0490 
0491     nlmsg_end(skb, nlh);
0492     return 0;
0493 }
0494 
0495 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
0496             struct netlink_callback *cb,
0497             const struct inet_diag_req_v2 *r,
0498             u16 nlmsg_flags, bool net_admin)
0499 {
0500     if (sk->sk_state == TCP_TIME_WAIT)
0501         return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
0502 
0503     if (sk->sk_state == TCP_NEW_SYN_RECV)
0504         return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
0505 
0506     return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
0507                  net_admin);
0508 }
0509 
0510 struct sock *inet_diag_find_one_icsk(struct net *net,
0511                      struct inet_hashinfo *hashinfo,
0512                      const struct inet_diag_req_v2 *req)
0513 {
0514     struct sock *sk;
0515 
0516     rcu_read_lock();
0517     if (req->sdiag_family == AF_INET)
0518         sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
0519                  req->id.idiag_dport, req->id.idiag_src[0],
0520                  req->id.idiag_sport, req->id.idiag_if);
0521 #if IS_ENABLED(CONFIG_IPV6)
0522     else if (req->sdiag_family == AF_INET6) {
0523         if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
0524             ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
0525             sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
0526                      req->id.idiag_dport, req->id.idiag_src[3],
0527                      req->id.idiag_sport, req->id.idiag_if);
0528         else
0529             sk = inet6_lookup(net, hashinfo, NULL, 0,
0530                       (struct in6_addr *)req->id.idiag_dst,
0531                       req->id.idiag_dport,
0532                       (struct in6_addr *)req->id.idiag_src,
0533                       req->id.idiag_sport,
0534                       req->id.idiag_if);
0535     }
0536 #endif
0537     else {
0538         rcu_read_unlock();
0539         return ERR_PTR(-EINVAL);
0540     }
0541     rcu_read_unlock();
0542     if (!sk)
0543         return ERR_PTR(-ENOENT);
0544 
0545     if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
0546         sock_gen_put(sk);
0547         return ERR_PTR(-ENOENT);
0548     }
0549 
0550     return sk;
0551 }
0552 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
0553 
0554 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
0555                 struct netlink_callback *cb,
0556                 const struct inet_diag_req_v2 *req)
0557 {
0558     struct sk_buff *in_skb = cb->skb;
0559     bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
0560     struct net *net = sock_net(in_skb->sk);
0561     struct sk_buff *rep;
0562     struct sock *sk;
0563     int err;
0564 
0565     sk = inet_diag_find_one_icsk(net, hashinfo, req);
0566     if (IS_ERR(sk))
0567         return PTR_ERR(sk);
0568 
0569     rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
0570     if (!rep) {
0571         err = -ENOMEM;
0572         goto out;
0573     }
0574 
0575     err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
0576     if (err < 0) {
0577         WARN_ON(err == -EMSGSIZE);
0578         nlmsg_free(rep);
0579         goto out;
0580     }
0581     err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid);
0582 
0583 out:
0584     if (sk)
0585         sock_gen_put(sk);
0586 
0587     return err;
0588 }
0589 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
0590 
0591 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
0592                    const struct nlmsghdr *nlh,
0593                    int hdrlen,
0594                    const struct inet_diag_req_v2 *req)
0595 {
0596     const struct inet_diag_handler *handler;
0597     struct inet_diag_dump_data dump_data;
0598     int err, protocol;
0599 
0600     memset(&dump_data, 0, sizeof(dump_data));
0601     err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
0602     if (err)
0603         return err;
0604 
0605     protocol = inet_diag_get_protocol(req, &dump_data);
0606 
0607     handler = inet_diag_lock_handler(protocol);
0608     if (IS_ERR(handler)) {
0609         err = PTR_ERR(handler);
0610     } else if (cmd == SOCK_DIAG_BY_FAMILY) {
0611         struct netlink_callback cb = {
0612             .nlh = nlh,
0613             .skb = in_skb,
0614             .data = &dump_data,
0615         };
0616         err = handler->dump_one(&cb, req);
0617     } else if (cmd == SOCK_DESTROY && handler->destroy) {
0618         err = handler->destroy(in_skb, req);
0619     } else {
0620         err = -EOPNOTSUPP;
0621     }
0622     inet_diag_unlock_handler(handler);
0623 
0624     return err;
0625 }
0626 
/*
 * Compare the leading @bits bits of @a1 and @a2.
 *
 * Both are arrays of 32-bit words in network byte order, so bit 0 is the
 * most significant bit of word 0.
 *
 * Return: 1 if the first @bits bits are equal, 0 otherwise.
 */
static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	const int nwords = bits >> 5;	/* whole 32-bit words to compare */
	const int rem = bits & 0x1f;	/* leading bits of the next word */

	if (nwords && memcmp(a1, a2, nwords << 2) != 0)
		return 0;

	if (rem) {
		const __be32 mask = htonl(0xffffffffU << (32 - rem));

		return ((a1[nwords] ^ a2[nwords]) & mask) == 0;
	}

	return 1;
}
0652 
0653 static int inet_diag_bc_run(const struct nlattr *_bc,
0654                 const struct inet_diag_entry *entry)
0655 {
0656     const void *bc = nla_data(_bc);
0657     int len = nla_len(_bc);
0658 
0659     while (len > 0) {
0660         int yes = 1;
0661         const struct inet_diag_bc_op *op = bc;
0662 
0663         switch (op->code) {
0664         case INET_DIAG_BC_NOP:
0665             break;
0666         case INET_DIAG_BC_JMP:
0667             yes = 0;
0668             break;
0669         case INET_DIAG_BC_S_EQ:
0670             yes = entry->sport == op[1].no;
0671             break;
0672         case INET_DIAG_BC_S_GE:
0673             yes = entry->sport >= op[1].no;
0674             break;
0675         case INET_DIAG_BC_S_LE:
0676             yes = entry->sport <= op[1].no;
0677             break;
0678         case INET_DIAG_BC_D_EQ:
0679             yes = entry->dport == op[1].no;
0680             break;
0681         case INET_DIAG_BC_D_GE:
0682             yes = entry->dport >= op[1].no;
0683             break;
0684         case INET_DIAG_BC_D_LE:
0685             yes = entry->dport <= op[1].no;
0686             break;
0687         case INET_DIAG_BC_AUTO:
0688             yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
0689             break;
0690         case INET_DIAG_BC_S_COND:
0691         case INET_DIAG_BC_D_COND: {
0692             const struct inet_diag_hostcond *cond;
0693             const __be32 *addr;
0694 
0695             cond = (const struct inet_diag_hostcond *)(op + 1);
0696             if (cond->port != -1 &&
0697                 cond->port != (op->code == INET_DIAG_BC_S_COND ?
0698                          entry->sport : entry->dport)) {
0699                 yes = 0;
0700                 break;
0701             }
0702 
0703             if (op->code == INET_DIAG_BC_S_COND)
0704                 addr = entry->saddr;
0705             else
0706                 addr = entry->daddr;
0707 
0708             if (cond->family != AF_UNSPEC &&
0709                 cond->family != entry->family) {
0710                 if (entry->family == AF_INET6 &&
0711                     cond->family == AF_INET) {
0712                     if (addr[0] == 0 && addr[1] == 0 &&
0713                         addr[2] == htonl(0xffff) &&
0714                         bitstring_match(addr + 3,
0715                                 cond->addr,
0716                                 cond->prefix_len))
0717                         break;
0718                 }
0719                 yes = 0;
0720                 break;
0721             }
0722 
0723             if (cond->prefix_len == 0)
0724                 break;
0725             if (bitstring_match(addr, cond->addr,
0726                         cond->prefix_len))
0727                 break;
0728             yes = 0;
0729             break;
0730         }
0731         case INET_DIAG_BC_DEV_COND: {
0732             u32 ifindex;
0733 
0734             ifindex = *((const u32 *)(op + 1));
0735             if (ifindex != entry->ifindex)
0736                 yes = 0;
0737             break;
0738         }
0739         case INET_DIAG_BC_MARK_COND: {
0740             struct inet_diag_markcond *cond;
0741 
0742             cond = (struct inet_diag_markcond *)(op + 1);
0743             if ((entry->mark & cond->mask) != cond->mark)
0744                 yes = 0;
0745             break;
0746         }
0747 #ifdef CONFIG_SOCK_CGROUP_DATA
0748         case INET_DIAG_BC_CGROUP_COND: {
0749             u64 cgroup_id;
0750 
0751             cgroup_id = get_unaligned((const u64 *)(op + 1));
0752             if (cgroup_id != entry->cgroup_id)
0753                 yes = 0;
0754             break;
0755         }
0756 #endif
0757         }
0758 
0759         if (yes) {
0760             len -= op->yes;
0761             bc += op->yes;
0762         } else {
0763             len -= op->no;
0764             bc += op->no;
0765         }
0766     }
0767     return len == 0;
0768 }
0769 
0770 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
0771  */
0772 static void entry_fill_addrs(struct inet_diag_entry *entry,
0773                  const struct sock *sk)
0774 {
0775 #if IS_ENABLED(CONFIG_IPV6)
0776     if (sk->sk_family == AF_INET6) {
0777         entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
0778         entry->daddr = sk->sk_v6_daddr.s6_addr32;
0779     } else
0780 #endif
0781     {
0782         entry->saddr = &sk->sk_rcv_saddr;
0783         entry->daddr = &sk->sk_daddr;
0784     }
0785 }
0786 
0787 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
0788 {
0789     struct inet_sock *inet = inet_sk(sk);
0790     struct inet_diag_entry entry;
0791 
0792     if (!bc)
0793         return 1;
0794 
0795     entry.family = sk->sk_family;
0796     entry_fill_addrs(&entry, sk);
0797     entry.sport = inet->inet_num;
0798     entry.dport = ntohs(inet->inet_dport);
0799     entry.ifindex = sk->sk_bound_dev_if;
0800     entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
0801     if (sk_fullsock(sk))
0802         entry.mark = sk->sk_mark;
0803     else if (sk->sk_state == TCP_NEW_SYN_RECV)
0804         entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
0805     else if (sk->sk_state == TCP_TIME_WAIT)
0806         entry.mark = inet_twsk(sk)->tw_mark;
0807     else
0808         entry.mark = 0;
0809 #ifdef CONFIG_SOCK_CGROUP_DATA
0810     entry.cgroup_id = sk_fullsock(sk) ?
0811         cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
0812 #endif
0813 
0814     return inet_diag_bc_run(bc, &entry);
0815 }
0816 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
0817 
/*
 * Check that a jump target is an op boundary.
 *
 * Starting at the beginning of the @len-byte program @bc, follow only the
 * "yes" offsets. Succeed if some step leaves exactly @cc bytes remaining.
 * @cc is the target expressed as a remaining length, as
 * inet_diag_bc_audit() computes it.
 *
 * Return: 1 if such a boundary is reached; 0 if the walk passes @cc or meets
 * a "yes" offset that is smaller than 4 or not a multiple of 4.
 */
static int valid_cc(const void *bc, int len, int cc)
{
	const unsigned char *cur = bc;

	while (len >= 0 && cc <= len) {
		const struct inet_diag_bc_op *op = (const void *)cur;

		if (cc == len)
			return 1;
		if (op->yes < 4 || (op->yes & 3))
			return 0;
		cur += op->yes;
		len -= op->yes;
	}

	return 0;
}
0834 
0835 /* data is u32 ifindex */
0836 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
0837               int *min_len)
0838 {
0839     /* Check ifindex space. */
0840     *min_len += sizeof(u32);
0841     if (len < *min_len)
0842         return false;
0843 
0844     return true;
0845 }
/*
 * Validate the operand of an INET_DIAG_BC_S_COND / INET_DIAG_BC_D_COND op.
 *
 * The operand is a struct inet_diag_hostcond followed by an address whose
 * size depends on the family: none for AF_UNSPEC, struct in_addr for
 * AF_INET, struct in6_addr for AF_INET6. Any other family is rejected.
 * *@min_len is grown by each piece as it is checked against @len. The
 * prefix length (in bits) may not exceed the address size.
 *
 * Return: true if the operand is well formed and fits within @len.
 */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	/* Read-only view of the operand: @op is const, so keep cond const. */
	const struct inet_diag_hostcond *cond;
	int addr_len;

	/* Check hostcond space. */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;
	cond = (const struct inet_diag_hostcond *)(op + 1);

	/* Check address family and address length. */
	switch (cond->family) {
	case AF_UNSPEC:
		addr_len = 0;
		break;
	case AF_INET:
		addr_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		addr_len = sizeof(struct in6_addr);
		break;
	default:
		return false;
	}
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Check prefix length (in bits) vs address length (in bytes). */
	if (cond->prefix_len > 8 * addr_len)
		return false;

	return true;
}
0883 
/*
 * Validate a port comparison op (INET_DIAG_BC_{S,D}_{EQ,GE,LE}).
 *
 * The port is carried in the ->no field of a second inet_diag_bc_op placed
 * right after this one. Grow *@min_len by that follow-on op and check it
 * fits within @len.
 */
static bool valid_port_comparison(const struct inet_diag_bc_op *op,
				  int len, int *min_len)
{
	*min_len += sizeof(struct inet_diag_bc_op);
	return len >= *min_len;
}
0894 
/*
 * INET_DIAG_BC_MARK_COND: a struct inet_diag_markcond follows the op.
 * Grow *@min_len by it and check it fits within @len.
 */
static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	const int need = *min_len + sizeof(struct inet_diag_markcond);

	*min_len = need;
	return need <= len;
}
0901 
#ifdef CONFIG_SOCK_CGROUP_DATA
/*
 * INET_DIAG_BC_CGROUP_COND: a u64 cgroup id follows the op.
 * Grow *@min_len by it and check it fits within @len.
 */
static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
			     int *min_len)
{
	*min_len += sizeof(u64);	/* cgroup id */
	return !(len < *min_len);
}
#endif
0910 
/*
 * Validate user-supplied filter bytecode (@attr) before inet_diag_bc_run()
 * interprets it; the interpreter itself does no bounds checking.
 *
 * Every op must use a known opcode, and its operand (if any) must fit in the
 * rest of the program. min_len accumulates the op + operand size. Both jump
 * offsets must be multiples of 4, at least min_len, and at most len + 4:
 * - a jump of exactly len reaches the end of the program, which is a match
 *   because inet_diag_bc_run() returns len == 0;
 * - len + 4 overshoots by one op, which is a non-match.
 * A "no" target that lands inside the program must also be an op boundary
 * reachable along the chain of "yes" offsets (valid_cc()).
 *
 * Return: 0 if the bytecode is valid, -EINVAL if it is malformed, or -EPERM
 * if it uses INET_DIAG_BC_MARK_COND without CAP_NET_ADMIN.
 */
static int inet_diag_bc_audit(const struct nlattr *attr,
			      const struct sk_buff *skb)
{
	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
	const void *bytecode, *bc;
	int bytecode_len, len;

	/* At least one whole op is required. */
	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
		return -EINVAL;

	bytecode = bc = nla_data(attr);
	len = bytecode_len = nla_len(attr);

	while (len > 0) {
		/* Grown by the valid_*() helpers to cover the op's operand. */
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_DEV_COND:
			if (!valid_devcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_EQ:
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_EQ:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_MARK_COND:
			/* Filtering on marks requires CAP_NET_ADMIN. */
			if (!net_admin)
				return -EPERM;
			if (!valid_markcond(bc, len, &min_len))
				return -EINVAL;
			break;
#ifdef CONFIG_SOCK_CGROUP_DATA
		case INET_DIAG_BC_CGROUP_COND:
			if (!valid_cgroupcond(bc, len, &min_len))
				return -EINVAL;
			break;
#endif
		/* Opcodes without an operand. */
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		/* inet_diag_bc_run() never takes a NOP's "no" branch, so its
		 * "no" offset is not checked.
		 */
		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			/* An in-program target must be a reachable op boundary. */
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc  += op->yes;
		len -= op->yes;
	}
	/* The "yes" chain must end exactly at the end of the program. */
	return len == 0 ? 0 : -EINVAL;
}
0982 
0983 static void twsk_build_assert(void)
0984 {
0985     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
0986              offsetof(struct sock, sk_family));
0987 
0988     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
0989              offsetof(struct inet_sock, inet_num));
0990 
0991     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
0992              offsetof(struct inet_sock, inet_dport));
0993 
0994     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
0995              offsetof(struct inet_sock, inet_rcv_saddr));
0996 
0997     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
0998              offsetof(struct inet_sock, inet_daddr));
0999 
1000 #if IS_ENABLED(CONFIG_IPV6)
1001     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
1002              offsetof(struct sock, sk_v6_rcv_saddr));
1003 
1004     BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
1005              offsetof(struct sock, sk_v6_daddr));
1006 #endif
1007 }
1008 
/*
 * inet_diag_dump_icsk - dump the connection-oriented sockets in @hashinfo.
 * @hashinfo: hash tables to walk, first the listening table (lhash2), then
 *            the established table (ehash).
 * @skb:      output buffer; one message is appended per matching socket.
 * @cb:       netlink dump state.  cb->args[] makes the walk resumable:
 *            args[0] - 0 while the listening table still has to be walked,
 *                      1 once it is done or skipped;
 *            args[1] - bucket index to resume from;
 *            args[2] - number of same-netns sockets already handled in
 *                      that bucket.
 * @r:        request.  It filters on state, family and ports, and the
 *            bytecode filter is taken from cb->data.
 *
 * When a fill helper fails (presumably because @skb is full), the walk
 * stops and saves its position in cb->args[].  The next call resumes from
 * there.
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
             struct netlink_callback *cb,
             const struct inet_diag_req_v2 *r)
{
    /* Requester's privilege; passed through to the fill helpers. */
    bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
    struct inet_diag_dump_data *cb_data = cb->data;
    struct net *net = sock_net(skb->sk);
    u32 idiag_states = r->idiag_states;
    int i, num, s_i, s_num;
    struct nlattr *bc;
    struct sock *sk;

    bc = cb_data->inet_diag_nla_bc;
    /* Asking for SYN_RECV also matches NEW_SYN_RECV entries. */
    if (idiag_states & TCPF_SYN_RECV)
        idiag_states |= TCPF_NEW_SYN_RECV;
    /* Resume point saved by the previous call (zero on the first call). */
    s_i = cb->args[1];
    s_num = num = cb->args[2];

    if (cb->args[0] == 0) {
        /*
         * Phase 1: listening sockets.  Skipped when LISTEN is not
         * requested, or when a destination-port filter is set
         * (presumably because listeners have no peer port to match).
         */
        if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
            goto skip_listen_ht;

        for (i = s_i; i <= hashinfo->lhash2_mask; i++) {
            struct inet_listen_hashbucket *ilb;
            struct hlist_nulls_node *node;

            num = 0;
            ilb = &hashinfo->lhash2[i];

            spin_lock(&ilb->lock);
            sk_nulls_for_each(sk, node, &ilb->nulls_head) {
                struct inet_sock *inet = inet_sk(sk);

                /* Sockets of other netns are not counted in num. */
                if (!net_eq(sock_net(sk), net))
                    continue;

                /* Skip entries already handled by an earlier call. */
                if (num < s_num) {
                    num++;
                    continue;
                }

                if (r->sdiag_family != AF_UNSPEC &&
                    sk->sk_family != r->sdiag_family)
                    goto next_listen;

                if (r->id.idiag_sport != inet->inet_sport &&
                    r->id.idiag_sport)
                    goto next_listen;

                if (!inet_diag_bc_sk(bc, sk))
                    goto next_listen;

                if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
                              cb, r, NLM_F_MULTI,
                              net_admin) < 0) {
                    /* i and num record where to resume. */
                    spin_unlock(&ilb->lock);
                    goto done;
                }

next_listen:
                ++num;
            }
            spin_unlock(&ilb->lock);

            /* Only the first bucket resumes part-way through. */
            s_num = 0;
        }
skip_listen_ht:
        cb->args[0] = 1;
        s_i = num = s_num = 0;
    }

    /* Only LISTEN requested: the established table is not walked. */
    if (!(idiag_states & ~TCPF_LISTEN))
        goto out;

/*
 * Phase 2 collects at most SKARR_SZ sockets per hold of the bucket lock
 * (taken with BH disabled).  A reference is taken on each collected socket,
 * and the messages are built after the lock has been dropped.
 */
#define SKARR_SZ 16
    for (i = s_i; i <= hashinfo->ehash_mask; i++) {
        struct inet_ehash_bucket *head = &hashinfo->ehash[i];
        spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
        struct hlist_nulls_node *node;
        struct sock *sk_arr[SKARR_SZ];
        int num_arr[SKARR_SZ];
        int idx, accum, res;

        if (hlist_nulls_empty(&head->chain))
            continue;

        /* Only the first bucket visited resumes part-way through. */
        if (i > s_i)
            s_num = 0;

next_chunk:
        num = 0;
        accum = 0;
        spin_lock_bh(lock);
        sk_nulls_for_each(sk, node, &head->chain) {
            int state;

            if (!net_eq(sock_net(sk), net))
                continue;
            if (num < s_num)
                goto next_normal;
            /* Time-wait entries are matched on their tw_substate. */
            state = (sk->sk_state == TCP_TIME_WAIT) ?
                inet_twsk(sk)->tw_substate : sk->sk_state;
            if (!(idiag_states & (1 << state)))
                goto next_normal;
            if (r->sdiag_family != AF_UNSPEC &&
                sk->sk_family != r->sdiag_family)
                goto next_normal;
            if (r->id.idiag_sport != htons(sk->sk_num) &&
                r->id.idiag_sport)
                goto next_normal;
            if (r->id.idiag_dport != sk->sk_dport &&
                r->id.idiag_dport)
                goto next_normal;
            /* The reads above may hit time-wait sockets; layout is
             * checked at build time.
             */
            twsk_build_assert();

            if (!inet_diag_bc_sk(bc, sk))
                goto next_normal;

            /* Skip sockets whose refcount has already dropped to zero. */
            if (!refcount_inc_not_zero(&sk->sk_refcnt))
                goto next_normal;

            num_arr[accum] = num;
            sk_arr[accum] = sk;
            /* Batch full: leave with num still indexing this socket. */
            if (++accum == SKARR_SZ)
                break;
next_normal:
            ++num;
        }
        spin_unlock_bh(lock);
        /*
         * Fill messages outside the lock.  After the first failure, stop
         * filling but still drop every reference taken above.  num is
         * rewound to the socket that failed, so it is retried next call.
         */
        res = 0;
        for (idx = 0; idx < accum; idx++) {
            if (res >= 0) {
                res = sk_diag_fill(sk_arr[idx], skb, cb, r,
                           NLM_F_MULTI, net_admin);
                if (res < 0)
                    num = num_arr[idx];
            }
            sock_gen_put(sk_arr[idx]);
        }
        if (res < 0)
            break;
        cond_resched();
        /* The batch filled up, so the bucket may hold more: rescan it,
         * skipping everything up to and including the last socket sent.
         */
        if (accum == SKARR_SZ) {
            s_num = num + 1;
            goto next_chunk;
        }
    }

done:
    /* Save the resume point for the next call. */
    cb->args[1] = i;
    cb->args[2] = num;
out:
    ;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1164 
1165 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1166                 const struct inet_diag_req_v2 *r)
1167 {
1168     struct inet_diag_dump_data *cb_data = cb->data;
1169     const struct inet_diag_handler *handler;
1170     u32 prev_min_dump_alloc;
1171     int protocol, err = 0;
1172 
1173     protocol = inet_diag_get_protocol(r, cb_data);
1174 
1175 again:
1176     prev_min_dump_alloc = cb->min_dump_alloc;
1177     handler = inet_diag_lock_handler(protocol);
1178     if (!IS_ERR(handler))
1179         handler->dump(skb, cb, r);
1180     else
1181         err = PTR_ERR(handler);
1182     inet_diag_unlock_handler(handler);
1183 
1184     /* The skb is not large enough to fit one sk info and
1185      * inet_sk_diag_fill() has requested for a larger skb.
1186      */
1187     if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
1188         err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
1189         if (!err)
1190             goto again;
1191     }
1192 
1193     return err ? : skb->len;
1194 }
1195 
1196 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1197 {
1198     return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
1199 }
1200 
1201 static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
1202 {
1203     const struct nlmsghdr *nlh = cb->nlh;
1204     struct inet_diag_dump_data *cb_data;
1205     struct sk_buff *skb = cb->skb;
1206     struct nlattr *nla;
1207     int err;
1208 
1209     cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
1210     if (!cb_data)
1211         return -ENOMEM;
1212 
1213     err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
1214     if (err) {
1215         kfree(cb_data);
1216         return err;
1217     }
1218     nla = cb_data->inet_diag_nla_bc;
1219     if (nla) {
1220         err = inet_diag_bc_audit(nla, skb);
1221         if (err) {
1222             kfree(cb_data);
1223             return err;
1224         }
1225     }
1226 
1227     nla = cb_data->inet_diag_nla_bpf_stgs;
1228     if (nla) {
1229         struct bpf_sk_storage_diag *bpf_stg_diag;
1230 
1231         bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
1232         if (IS_ERR(bpf_stg_diag)) {
1233             kfree(cb_data);
1234             return PTR_ERR(bpf_stg_diag);
1235         }
1236         cb_data->bpf_stg_diag = bpf_stg_diag;
1237     }
1238 
1239     cb->data = cb_data;
1240     return 0;
1241 }
1242 
1243 static int inet_diag_dump_start(struct netlink_callback *cb)
1244 {
1245     return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
1246 }
1247 
1248 static int inet_diag_dump_start_compat(struct netlink_callback *cb)
1249 {
1250     return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
1251 }
1252 
1253 static int inet_diag_dump_done(struct netlink_callback *cb)
1254 {
1255     struct inet_diag_dump_data *cb_data = cb->data;
1256 
1257     bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
1258     kfree(cb->data);
1259 
1260     return 0;
1261 }
1262 
/*
 * Map a legacy diag netlink message type to the transport protocol it
 * queries: TCPDIAG_GETSOCK -> IPPROTO_TCP, DCCPDIAG_GETSOCK -> IPPROTO_DCCP.
 * Any other type yields 0.
 */
static int inet_diag_type2proto(int type)
{
    if (type == TCPDIAG_GETSOCK)
        return IPPROTO_TCP;
    if (type == DCCPDIAG_GETSOCK)
        return IPPROTO_DCCP;
    return 0;
}
1274 
1275 static int inet_diag_dump_compat(struct sk_buff *skb,
1276                  struct netlink_callback *cb)
1277 {
1278     struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1279     struct inet_diag_req_v2 req;
1280 
1281     req.sdiag_family = AF_UNSPEC; /* compatibility */
1282     req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1283     req.idiag_ext = rc->idiag_ext;
1284     req.idiag_states = rc->idiag_states;
1285     req.id = rc->id;
1286 
1287     return __inet_diag_dump(skb, cb, &req);
1288 }
1289 
1290 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1291                       const struct nlmsghdr *nlh)
1292 {
1293     struct inet_diag_req *rc = nlmsg_data(nlh);
1294     struct inet_diag_req_v2 req;
1295 
1296     req.sdiag_family = rc->idiag_family;
1297     req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1298     req.idiag_ext = rc->idiag_ext;
1299     req.idiag_states = rc->idiag_states;
1300     req.id = rc->id;
1301 
1302     return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
1303                    sizeof(struct inet_diag_req), &req);
1304 }
1305 
1306 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1307 {
1308     int hdrlen = sizeof(struct inet_diag_req);
1309     struct net *net = sock_net(skb->sk);
1310 
1311     if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1312         nlmsg_len(nlh) < hdrlen)
1313         return -EINVAL;
1314 
1315     if (nlh->nlmsg_flags & NLM_F_DUMP) {
1316         struct netlink_dump_control c = {
1317             .start = inet_diag_dump_start_compat,
1318             .done = inet_diag_dump_done,
1319             .dump = inet_diag_dump_compat,
1320         };
1321         return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1322     }
1323 
1324     return inet_diag_get_exact_compat(skb, nlh);
1325 }
1326 
1327 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1328 {
1329     int hdrlen = sizeof(struct inet_diag_req_v2);
1330     struct net *net = sock_net(skb->sk);
1331 
1332     if (nlmsg_len(h) < hdrlen)
1333         return -EINVAL;
1334 
1335     if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1336         h->nlmsg_flags & NLM_F_DUMP) {
1337         struct netlink_dump_control c = {
1338             .start = inet_diag_dump_start,
1339             .done = inet_diag_dump_done,
1340             .dump = inet_diag_dump,
1341         };
1342         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1343     }
1344 
1345     return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
1346                    nlmsg_data(h));
1347 }
1348 
1349 static
1350 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1351 {
1352     const struct inet_diag_handler *handler;
1353     struct nlmsghdr *nlh;
1354     struct nlattr *attr;
1355     struct inet_diag_msg *r;
1356     void *info = NULL;
1357     int err = 0;
1358 
1359     nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1360     if (!nlh)
1361         return -ENOMEM;
1362 
1363     r = nlmsg_data(nlh);
1364     memset(r, 0, sizeof(*r));
1365     inet_diag_msg_common_fill(r, sk);
1366     if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1367         r->id.idiag_sport = inet_sk(sk)->inet_sport;
1368     r->idiag_state = sk->sk_state;
1369 
1370     if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1371         nlmsg_cancel(skb, nlh);
1372         return err;
1373     }
1374 
1375     handler = inet_diag_lock_handler(sk->sk_protocol);
1376     if (IS_ERR(handler)) {
1377         inet_diag_unlock_handler(handler);
1378         nlmsg_cancel(skb, nlh);
1379         return PTR_ERR(handler);
1380     }
1381 
1382     attr = handler->idiag_info_size
1383         ? nla_reserve_64bit(skb, INET_DIAG_INFO,
1384                     handler->idiag_info_size,
1385                     INET_DIAG_PAD)
1386         : NULL;
1387     if (attr)
1388         info = nla_data(attr);
1389 
1390     handler->idiag_get_info(sk, r, info);
1391     inet_diag_unlock_handler(handler);
1392 
1393     nlmsg_end(skb, nlh);
1394     return 0;
1395 }
1396 
1397 static const struct sock_diag_handler inet_diag_handler = {
1398     .family = AF_INET,
1399     .dump = inet_diag_handler_cmd,
1400     .get_info = inet_diag_handler_get_info,
1401     .destroy = inet_diag_handler_cmd,
1402 };
1403 
1404 static const struct sock_diag_handler inet6_diag_handler = {
1405     .family = AF_INET6,
1406     .dump = inet_diag_handler_cmd,
1407     .get_info = inet_diag_handler_get_info,
1408     .destroy = inet_diag_handler_cmd,
1409 };
1410 
1411 int inet_diag_register(const struct inet_diag_handler *h)
1412 {
1413     const __u16 type = h->idiag_type;
1414     int err = -EINVAL;
1415 
1416     if (type >= IPPROTO_MAX)
1417         goto out;
1418 
1419     mutex_lock(&inet_diag_table_mutex);
1420     err = -EEXIST;
1421     if (!inet_diag_table[type]) {
1422         inet_diag_table[type] = h;
1423         err = 0;
1424     }
1425     mutex_unlock(&inet_diag_table_mutex);
1426 out:
1427     return err;
1428 }
1429 EXPORT_SYMBOL_GPL(inet_diag_register);
1430 
1431 void inet_diag_unregister(const struct inet_diag_handler *h)
1432 {
1433     const __u16 type = h->idiag_type;
1434 
1435     if (type >= IPPROTO_MAX)
1436         return;
1437 
1438     mutex_lock(&inet_diag_table_mutex);
1439     inet_diag_table[type] = NULL;
1440     mutex_unlock(&inet_diag_table_mutex);
1441 }
1442 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1443 
1444 static int __init inet_diag_init(void)
1445 {
1446     const int inet_diag_table_size = (IPPROTO_MAX *
1447                       sizeof(struct inet_diag_handler *));
1448     int err = -ENOMEM;
1449 
1450     inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1451     if (!inet_diag_table)
1452         goto out;
1453 
1454     err = sock_diag_register(&inet_diag_handler);
1455     if (err)
1456         goto out_free_nl;
1457 
1458     err = sock_diag_register(&inet6_diag_handler);
1459     if (err)
1460         goto out_free_inet;
1461 
1462     sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1463 out:
1464     return err;
1465 
1466 out_free_inet:
1467     sock_diag_unregister(&inet_diag_handler);
1468 out_free_nl:
1469     kfree(inet_diag_table);
1470     goto out;
1471 }
1472 
/*
 * Module teardown.  Unregisters the AF_INET6 and AF_INET sock_diag handlers
 * and the legacy compat receive hook, then frees the handler table
 * allocated in inet_diag_init().
 */
static void __exit inet_diag_exit(void)
{
    sock_diag_unregister(&inet6_diag_handler);
    sock_diag_unregister(&inet_diag_handler);
    sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
    kfree(inet_diag_table);
}
1480 
/*
 * Module entry points, plus aliases that match NETLINK_SOCK_DIAG for
 * AF_INET (2) and AF_INET6 (10).
 */
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);