Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 #include <linux/module.h>
0003 
0004 #include <net/sock.h>
0005 #include <linux/netlink.h>
0006 #include <linux/sock_diag.h>
0007 #include <linux/netlink_diag.h>
0008 #include <linux/rhashtable.h>
0009 
0010 #include "af_netlink.h"
0011 
0012 static int sk_diag_dump_groups(struct sock *sk, struct sk_buff *nlskb)
0013 {
0014     struct netlink_sock *nlk = nlk_sk(sk);
0015 
0016     if (nlk->groups == NULL)
0017         return 0;
0018 
0019     return nla_put(nlskb, NETLINK_DIAG_GROUPS, NLGRPSZ(nlk->ngroups),
0020                nlk->groups);
0021 }
0022 
0023 static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb)
0024 {
0025     struct netlink_sock *nlk = nlk_sk(sk);
0026     u32 flags = 0;
0027 
0028     if (nlk->cb_running)
0029         flags |= NDIAG_FLAG_CB_RUNNING;
0030     if (nlk->flags & NETLINK_F_RECV_PKTINFO)
0031         flags |= NDIAG_FLAG_PKTINFO;
0032     if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
0033         flags |= NDIAG_FLAG_BROADCAST_ERROR;
0034     if (nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)
0035         flags |= NDIAG_FLAG_NO_ENOBUFS;
0036     if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID)
0037         flags |= NDIAG_FLAG_LISTEN_ALL_NSID;
0038     if (nlk->flags & NETLINK_F_CAP_ACK)
0039         flags |= NDIAG_FLAG_CAP_ACK;
0040 
0041     return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags);
0042 }
0043 
0044 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
0045             struct netlink_diag_req *req,
0046             u32 portid, u32 seq, u32 flags, int sk_ino)
0047 {
0048     struct nlmsghdr *nlh;
0049     struct netlink_diag_msg *rep;
0050     struct netlink_sock *nlk = nlk_sk(sk);
0051 
0052     nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
0053             flags);
0054     if (!nlh)
0055         return -EMSGSIZE;
0056 
0057     rep = nlmsg_data(nlh);
0058     rep->ndiag_family   = AF_NETLINK;
0059     rep->ndiag_type     = sk->sk_type;
0060     rep->ndiag_protocol = sk->sk_protocol;
0061     rep->ndiag_state    = sk->sk_state;
0062 
0063     rep->ndiag_ino      = sk_ino;
0064     rep->ndiag_portid   = nlk->portid;
0065     rep->ndiag_dst_portid   = nlk->dst_portid;
0066     rep->ndiag_dst_group    = nlk->dst_group;
0067     sock_diag_save_cookie(sk, rep->ndiag_cookie);
0068 
0069     if ((req->ndiag_show & NDIAG_SHOW_GROUPS) &&
0070         sk_diag_dump_groups(sk, skb))
0071         goto out_nlmsg_trim;
0072 
0073     if ((req->ndiag_show & NDIAG_SHOW_MEMINFO) &&
0074         sock_diag_put_meminfo(sk, skb, NETLINK_DIAG_MEMINFO))
0075         goto out_nlmsg_trim;
0076 
0077     if ((req->ndiag_show & NDIAG_SHOW_FLAGS) &&
0078         sk_diag_put_flags(sk, skb))
0079         goto out_nlmsg_trim;
0080 
0081     nlmsg_end(skb, nlh);
0082     return 0;
0083 
0084 out_nlmsg_trim:
0085     nlmsg_cancel(skb, nlh);
0086     return -EMSGSIZE;
0087 }
0088 
/*
 * Dump the netlink sockets of one @protocol that belong to the requesting
 * socket's netns into @skb. The dump may span several calls and resumes
 * from @s_num.
 *
 * The walk has two phases. Progress is recorded in cb->args[0] (written at
 * the "done" label):
 *   - value 1:  phase one, walking tbl->hash with the rhashtable iterator
 *               stored in cb->args[2]. @s_num == 0 enters a new walk;
 *               @s_num == 1 resumes the walk that is already entered.
 *   - value >= 2: phase two, walking tbl->mc_list. The counter starts at 2
 *               and is incremented for every mc_list socket that passes the
 *               filters, so on resume entries with num < @s_num are skipped.
 *
 * The iterator is kmalloc'ed on first use and kept in cb->args[2]. When this
 * function returns non-zero during phase one, the walk is stopped but not
 * exited; netlink_diag_dump_done() exits it when cb->args[0] == 1 and frees
 * the iterator.
 *
 * Returns:
 *   0        - all sockets for @protocol were dumped;
 *   1        - @skb filled up (sk_diag_fill() failed); call again to resume;
 *   -ENOMEM  - the iterator could not be allocated (cb->args[0] untouched);
 *   other negative value from rhashtable_walk_next() except -EAGAIN.
 */
static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
                int protocol, int s_num)
{
    struct rhashtable_iter *hti = (void *)cb->args[2];
    struct netlink_table *tbl = &nl_table[protocol];
    struct net *net = sock_net(skb->sk);
    struct netlink_diag_req *req;
    struct netlink_sock *nlsk;
    struct sock *sk;
    int num = 2;
    int ret = 0;

    req = nlmsg_data(cb->nlh);

    /* Hash phase already completed on an earlier call: go straight to mc_list. */
    if (s_num > 1)
        goto mc_list;

    /* We are in the hash phase; cb->args[0] == 1 marks it if we stop early. */
    num--;

    if (!hti) {
        hti = kmalloc(sizeof(*hti), GFP_KERNEL);
        if (!hti)
            return -ENOMEM;

        cb->args[2] = (long)hti;
    }

    /* Only a fresh dump of this protocol enters the walk; a resume (s_num == 1) reuses it. */
    if (!s_num)
        rhashtable_walk_enter(&tbl->hash, hti);

    rhashtable_walk_start(hti);

    while ((nlsk = rhashtable_walk_next(hti))) {
        if (IS_ERR(nlsk)) {
            ret = PTR_ERR(nlsk);
            /*
             * -EAGAIN is ignored and the walk continues (presumably the
             * table was resized underneath us; entries may repeat).
             */
            if (ret == -EAGAIN) {
                ret = 0;
                continue;
            }
            break;
        }

        sk = (struct sock *)nlsk;

        /* Only report sockets from the requester's network namespace. */
        if (!net_eq(sock_net(sk), net))
            continue;

        /* Message did not fit: stop here; the next call resumes the walk. */
        if (sk_diag_fill(sk, skb, req,
                 NETLINK_CB(cb->skb).portid,
                 cb->nlh->nlmsg_seq,
                 NLM_F_MULTI,
                 sock_i_ino(sk)) < 0) {
            ret = 1;
            break;
        }
    }

    rhashtable_walk_stop(hti);

    /* Leave the walk entered (cb->args[0] becomes 1) so it can be resumed or torn down later. */
    if (ret)
        goto done;

    /* Hash phase finished: release the walk and switch to the mc_list phase. */
    rhashtable_walk_exit(hti);
    num++;

mc_list:
    /*
     * NOTE(review): newer upstream kernels take nl_table_lock here with
     * read_lock_irqsave() and call __sock_i_ino() instead of sock_i_ino()
     * to avoid a lock-ordering problem; confirm whether this tree needs
     * the same change.
     */
    read_lock(&nl_table_lock);
    sk_for_each_bound(sk, &tbl->mc_list) {
        /*
         * Skip sockets for which sk_hashed() is true (presumably to avoid
         * reporting a socket twice — confirm it still filters anything).
         */
        if (sk_hashed(sk))
            continue;
        if (!net_eq(sock_net(sk), net))
            continue;
        /* Already reported on a previous call. */
        if (num < s_num) {
            num++;
            continue;
        }

        if (sk_diag_fill(sk, skb, req,
                 NETLINK_CB(cb->skb).portid,
                 cb->nlh->nlmsg_seq,
                 NLM_F_MULTI,
                 sock_i_ino(sk)) < 0) {
            ret = 1;
            break;
        }
        num++;
    }
    read_unlock(&nl_table_lock);

done:
    /* Persist the resume position for the next call and for dump_done. */
    cb->args[0] = num;

    return ret;
}
0183 
0184 static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
0185 {
0186     struct netlink_diag_req *req;
0187     int s_num = cb->args[0];
0188     int err = 0;
0189 
0190     req = nlmsg_data(cb->nlh);
0191 
0192     if (req->sdiag_protocol == NDIAG_PROTO_ALL) {
0193         int i;
0194 
0195         for (i = cb->args[1]; i < MAX_LINKS; i++) {
0196             err = __netlink_diag_dump(skb, cb, i, s_num);
0197             if (err)
0198                 break;
0199             s_num = 0;
0200         }
0201         cb->args[1] = i;
0202     } else {
0203         if (req->sdiag_protocol >= MAX_LINKS)
0204             return -ENOENT;
0205 
0206         err = __netlink_diag_dump(skb, cb, req->sdiag_protocol, s_num);
0207     }
0208 
0209     return err < 0 ? err : skb->len;
0210 }
0211 
0212 static int netlink_diag_dump_done(struct netlink_callback *cb)
0213 {
0214     struct rhashtable_iter *hti = (void *)cb->args[2];
0215 
0216     if (cb->args[0] == 1)
0217         rhashtable_walk_exit(hti);
0218 
0219     kfree(hti);
0220 
0221     return 0;
0222 }
0223 
0224 static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
0225 {
0226     int hdrlen = sizeof(struct netlink_diag_req);
0227     struct net *net = sock_net(skb->sk);
0228 
0229     if (nlmsg_len(h) < hdrlen)
0230         return -EINVAL;
0231 
0232     if (h->nlmsg_flags & NLM_F_DUMP) {
0233         struct netlink_dump_control c = {
0234             .dump = netlink_diag_dump,
0235             .done = netlink_diag_dump_done,
0236         };
0237         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
0238     } else
0239         return -EOPNOTSUPP;
0240 }
0241 
0242 static const struct sock_diag_handler netlink_diag_handler = {
0243     .family = AF_NETLINK,
0244     .dump = netlink_diag_handler_dump,
0245 };
0246 
0247 static int __init netlink_diag_init(void)
0248 {
0249     return sock_diag_register(&netlink_diag_handler);
0250 }
0251 
0252 static void __exit netlink_diag_exit(void)
0253 {
0254     sock_diag_unregister(&netlink_diag_handler);
0255 }
0256 
0257 module_init(netlink_diag_init);
0258 module_exit(netlink_diag_exit);
0259 MODULE_LICENSE("GPL");
0260 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 16 /* AF_NETLINK */);