// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook  */
#include <linux/rculist.h>
#include <linux/list.h>
#include <linux/hash.h>
#include <linux/types.h>
#include <linux/spinlock.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/bpf_local_storage.h>
#include <net/bpf_sk_storage.h>
#include <net/sock.h>
#include <uapi/linux/sock_diag.h>
#include <uapi/linux/btf.h>
#include <linux/rcupdate_trace.h>

DEFINE_BPF_STORAGE_CACHE(sk_cache);

static struct bpf_local_storage_data *
bpf_sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
{
    struct bpf_local_storage *sk_storage;
    struct bpf_local_storage_map *smap;

    sk_storage =
        rcu_dereference_check(sk->sk_bpf_storage, bpf_rcu_lock_held());
    if (!sk_storage)
        return NULL;

    smap = (struct bpf_local_storage_map *)map;
    return bpf_local_storage_lookup(sk_storage, smap, cacheit_lockit);
}

static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map)
{
    struct bpf_local_storage_data *sdata;

    sdata = bpf_sk_storage_lookup(sk, map, false);
    if (!sdata)
        return -ENOENT;

    bpf_selem_unlink(SELEM(sdata), true);

    return 0;
}

/* Called by __sk_destruct() & bpf_sk_storage_clone() */
void bpf_sk_storage_free(struct sock *sk)
{
    struct bpf_local_storage_elem *selem;
    struct bpf_local_storage *sk_storage;
    bool free_sk_storage = false;
    struct hlist_node *n;

    rcu_read_lock();
    sk_storage = rcu_dereference(sk->sk_bpf_storage);
    if (!sk_storage) {
        rcu_read_unlock();
        return;
    }

    /* Neither the bpf_prog nor the bpf-map's syscall
     * could be modifying the sk_storage->list now.
     * Thus, no elem can be added to or deleted from the
     * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
     *
     * It is racing with bpf_local_storage_map_free() alone
     * when unlinking elem from the sk_storage->list and
     * the map's bucket->list.
     */
    raw_spin_lock_bh(&sk_storage->lock);
    hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
        /* Always unlink from map before unlinking from
         * sk_storage.
         */
        bpf_selem_unlink_map(selem);
        free_sk_storage = bpf_selem_unlink_storage_nolock(
            sk_storage, selem, true, false);
    }
    raw_spin_unlock_bh(&sk_storage->lock);
    rcu_read_unlock();

    if (free_sk_storage)
        kfree_rcu(sk_storage, rcu);
}

static void bpf_sk_storage_map_free(struct bpf_map *map)
{
    struct bpf_local_storage_map *smap;

    smap = (struct bpf_local_storage_map *)map;
    bpf_local_storage_cache_idx_free(&sk_cache, smap->cache_idx);
    bpf_local_storage_map_free(smap, NULL);
}

static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
{
    struct bpf_local_storage_map *smap;

    smap = bpf_local_storage_map_alloc(attr);
    if (IS_ERR(smap))
        return ERR_CAST(smap);

    smap->cache_idx = bpf_local_storage_cache_idx_get(&sk_cache);
    return &smap->map;
}
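
/*
 * Illustrative sketch, not part of this file: roughly how a
 * BPF_MAP_TYPE_SK_STORAGE map is declared on the BPF-program side.  The
 * key type is int (the socket itself is the real owner), the value type
 * is arbitrary, and BPF_F_NO_PREALLOC is required because elements are
 * allocated on demand, one per socket.  The value layout and map name
 * below are assumptions made up for the example.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>

struct sk_stg {
    __u64 pkt_count;
    __u64 created_ns;
};

struct {
    __uint(type, BPF_MAP_TYPE_SK_STORAGE);
    __uint(map_flags, BPF_F_NO_PREALLOC);
    __type(key, int);
    __type(value, struct sk_stg);
} sk_stg_map SEC(".maps");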

static int notsupp_get_next_key(struct bpf_map *map, void *key,
                void *next_key)
{
    return -ENOTSUPP;
}

static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
{
    struct bpf_local_storage_data *sdata;
    struct socket *sock;
    int fd, err;

    fd = *(int *)key;
    sock = sockfd_lookup(fd, &err);
    if (sock) {
        sdata = bpf_sk_storage_lookup(sock->sk, map, true);
        sockfd_put(sock);
        return sdata ? sdata->data : NULL;
    }

    return ERR_PTR(err);
}

static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
                     void *value, u64 map_flags)
{
    struct bpf_local_storage_data *sdata;
    struct socket *sock;
    int fd, err;

    fd = *(int *)key;
    sock = sockfd_lookup(fd, &err);
    if (sock) {
        sdata = bpf_local_storage_update(
            sock->sk, (struct bpf_local_storage_map *)map, value,
            map_flags, GFP_ATOMIC);
        sockfd_put(sock);
        return PTR_ERR_OR_ZERO(sdata);
    }

    return err;
}

static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
{
    struct socket *sock;
    int fd, err;

    fd = *(int *)key;
    sock = sockfd_lookup(fd, &err);
    if (sock) {
        err = bpf_sk_storage_del(sock->sk, map);
        sockfd_put(sock);
        return err;
    }

    return err;
}
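
/*
 * Illustrative sketch, not part of this file: the bpf_fd_*() map ops
 * above back the normal BPF map syscalls, with the key being a socket
 * file descriptor owned by the calling process.  Note that
 * map_get_next_key is not supported, so these maps cannot be walked with
 * the usual get_next_key loop; the bpf_sk_storage_map iterator further
 * below is used for that instead.  A minimal user-space sequence might
 * look like this (error handling trimmed; struct sk_stg is the assumed
 * value type from the earlier sketch).
 */
#include <bpf/bpf.h>

struct sk_stg {
    __u64 pkt_count;
    __u64 created_ns;
};

int touch_sk_storage(int map_fd, int sock_fd)
{
    struct sk_stg val;
    int err;

    /* Look up this socket's element; fails with -ENOENT if none exists. */
    err = bpf_map_lookup_elem(map_fd, &sock_fd, &val);
    if (err)
        return err;

    /* Rewrite the element in place ... */
    val.pkt_count = 0;
    err = bpf_map_update_elem(map_fd, &sock_fd, &val, BPF_EXIST);
    if (err)
        return err;

    /* ... or drop the per-socket storage entirely. */
    return bpf_map_delete_elem(map_fd, &sock_fd);
}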

static struct bpf_local_storage_elem *
bpf_sk_storage_clone_elem(struct sock *newsk,
              struct bpf_local_storage_map *smap,
              struct bpf_local_storage_elem *selem)
{
    struct bpf_local_storage_elem *copy_selem;

    copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
    if (!copy_selem)
        return NULL;

    if (map_value_has_spin_lock(&smap->map))
        copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
                      SDATA(selem)->data, true);
    else
        copy_map_value(&smap->map, SDATA(copy_selem)->data,
                   SDATA(selem)->data);

    return copy_selem;
}

int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
{
    struct bpf_local_storage *new_sk_storage = NULL;
    struct bpf_local_storage *sk_storage;
    struct bpf_local_storage_elem *selem;
    int ret = 0;

    RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);

    rcu_read_lock();
    sk_storage = rcu_dereference(sk->sk_bpf_storage);

    if (!sk_storage || hlist_empty(&sk_storage->list))
        goto out;

    hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
        struct bpf_local_storage_elem *copy_selem;
        struct bpf_local_storage_map *smap;
        struct bpf_map *map;

        smap = rcu_dereference(SDATA(selem)->smap);
        if (!(smap->map.map_flags & BPF_F_CLONE))
            continue;

        /* Note that for lockless listeners adding a new element
         * here can race with cleanup in bpf_local_storage_map_free.
         * Try to grab the map refcnt to make sure that it's still
         * alive and to prevent concurrent removal.
         */
        map = bpf_map_inc_not_zero(&smap->map);
        if (IS_ERR(map))
            continue;

        copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
        if (!copy_selem) {
            ret = -ENOMEM;
            bpf_map_put(map);
            goto out;
        }

        if (new_sk_storage) {
            bpf_selem_link_map(smap, copy_selem);
            bpf_selem_link_storage_nolock(new_sk_storage, copy_selem);
        } else {
            ret = bpf_local_storage_alloc(newsk, smap, copy_selem, GFP_ATOMIC);
            if (ret) {
                kfree(copy_selem);
                atomic_sub(smap->elem_size,
                       &newsk->sk_omem_alloc);
                bpf_map_put(map);
                goto out;
            }

            new_sk_storage =
                rcu_dereference(copy_selem->local_storage);
        }
        bpf_map_put(map);
    }

out:
    rcu_read_unlock();

    /* In case of an error, don't free anything explicitly here, the
     * caller is responsible to call bpf_sk_storage_free.
     */

    return ret;
}
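
/*
 * Illustrative sketch, not part of this file: bpf_sk_storage_clone()
 * above only copies elements whose map was created with BPF_F_CLONE, so
 * storage set on a listening socket follows the child socket returned by
 * accept().  On the BPF side that is just an extra map flag; the names
 * and value type here are assumptions for the example.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>

struct conn_info {
    __u64 listener_cookie;
};

struct {
    __uint(type, BPF_MAP_TYPE_SK_STORAGE);
    __uint(map_flags, BPF_F_NO_PREALLOC | BPF_F_CLONE);
    __type(key, int);
    __type(value, struct conn_info);
} inherited_stg SEC(".maps");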

/* *gfp_flags* is a hidden argument provided by the verifier */
BPF_CALL_5(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
       void *, value, u64, flags, gfp_t, gfp_flags)
{
    struct bpf_local_storage_data *sdata;

    WARN_ON_ONCE(!bpf_rcu_lock_held());
    if (!sk || !sk_fullsock(sk) || flags > BPF_SK_STORAGE_GET_F_CREATE)
        return (unsigned long)NULL;

    sdata = bpf_sk_storage_lookup(sk, map, true);
    if (sdata)
        return (unsigned long)sdata->data;

    if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
        /* Cannot add a new elem to an sk that is going away.
         * Otherwise, the new elem may become a leak
         * (and cause other memory issues during map
         *  destruction).
         */
        refcount_inc_not_zero(&sk->sk_refcnt)) {
        sdata = bpf_local_storage_update(
            sk, (struct bpf_local_storage_map *)map, value,
            BPF_NOEXIST, gfp_flags);
        /* sk must be a fullsock (guaranteed by verifier),
         * so sock_gen_put() is unnecessary.
         */
        sock_put(sk);
        return IS_ERR(sdata) ?
            (unsigned long)NULL : (unsigned long)sdata->data;
    }

    return (unsigned long)NULL;
}

BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
{
    WARN_ON_ONCE(!bpf_rcu_lock_held());
    if (!sk || !sk_fullsock(sk))
        return -EINVAL;

    if (refcount_inc_not_zero(&sk->sk_refcnt)) {
        int err;

        err = bpf_sk_storage_del(sk, map);
        sock_put(sk);
        return err;
    }

    return -ENOENT;
}
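
/*
 * Illustrative sketch, not part of this file: how a BPF program calls the
 * two helpers wrapped above.  The sockops attach point, map and value
 * layout are assumptions for the example; BPF_SK_STORAGE_GET_F_CREATE
 * makes the helper allocate the element on first use.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>

char _license[] SEC("license") = "GPL";

struct sk_stg {
    __u64 pkt_count;
    __u64 created_ns;
};

struct {
    __uint(type, BPF_MAP_TYPE_SK_STORAGE);
    __uint(map_flags, BPF_F_NO_PREALLOC);
    __type(key, int);
    __type(value, struct sk_stg);
} sk_stg_map SEC(".maps");

SEC("sockops")
int stamp_established(struct bpf_sock_ops *ctx)
{
    struct bpf_sock *sk = ctx->sk;
    struct sk_stg *stg;

    if (!sk || ctx->op != BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB)
        return 1;

    /* Create (or find) this socket's element and stamp it. */
    stg = bpf_sk_storage_get(&sk_stg_map, sk, NULL,
                 BPF_SK_STORAGE_GET_F_CREATE);
    if (stg)
        stg->created_ns = bpf_ktime_get_ns();

    /* The counterpart is bpf_sk_storage_delete(&sk_stg_map, sk). */
    return 1;
}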

static int bpf_sk_storage_charge(struct bpf_local_storage_map *smap,
                 void *owner, u32 size)
{
    int optmem_max = READ_ONCE(sysctl_optmem_max);
    struct sock *sk = (struct sock *)owner;

    /* same check as in sock_kmalloc() */
    if (size <= optmem_max &&
        atomic_read(&sk->sk_omem_alloc) + size < optmem_max) {
        atomic_add(size, &sk->sk_omem_alloc);
        return 0;
    }

    return -ENOMEM;
}

static void bpf_sk_storage_uncharge(struct bpf_local_storage_map *smap,
                    void *owner, u32 size)
{
    struct sock *sk = owner;

    atomic_sub(size, &sk->sk_omem_alloc);
}

static struct bpf_local_storage __rcu **
bpf_sk_storage_ptr(void *owner)
{
    struct sock *sk = owner;

    return &sk->sk_bpf_storage;
}

BTF_ID_LIST_SINGLE(sk_storage_map_btf_ids, struct, bpf_local_storage_map)
const struct bpf_map_ops sk_storage_map_ops = {
    .map_meta_equal = bpf_map_meta_equal,
    .map_alloc_check = bpf_local_storage_map_alloc_check,
    .map_alloc = bpf_sk_storage_map_alloc,
    .map_free = bpf_sk_storage_map_free,
    .map_get_next_key = notsupp_get_next_key,
    .map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
    .map_update_elem = bpf_fd_sk_storage_update_elem,
    .map_delete_elem = bpf_fd_sk_storage_delete_elem,
    .map_check_btf = bpf_local_storage_map_check_btf,
    .map_btf_id = &sk_storage_map_btf_ids[0],
    .map_local_storage_charge = bpf_sk_storage_charge,
    .map_local_storage_uncharge = bpf_sk_storage_uncharge,
    .map_owner_storage_ptr = bpf_sk_storage_ptr,
};

const struct bpf_func_proto bpf_sk_storage_get_proto = {
    .func       = bpf_sk_storage_get,
    .gpl_only   = false,
    .ret_type   = RET_PTR_TO_MAP_VALUE_OR_NULL,
    .arg1_type  = ARG_CONST_MAP_PTR,
    .arg2_type  = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
    .arg3_type  = ARG_PTR_TO_MAP_VALUE_OR_NULL,
    .arg4_type  = ARG_ANYTHING,
};

const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto = {
    .func       = bpf_sk_storage_get,
    .gpl_only   = false,
    .ret_type   = RET_PTR_TO_MAP_VALUE_OR_NULL,
    .arg1_type  = ARG_CONST_MAP_PTR,
    .arg2_type  = ARG_PTR_TO_CTX, /* context is 'struct sock' */
    .arg3_type  = ARG_PTR_TO_MAP_VALUE_OR_NULL,
    .arg4_type  = ARG_ANYTHING,
};

const struct bpf_func_proto bpf_sk_storage_delete_proto = {
    .func       = bpf_sk_storage_delete,
    .gpl_only   = false,
    .ret_type   = RET_INTEGER,
    .arg1_type  = ARG_CONST_MAP_PTR,
    .arg2_type  = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
};

static bool bpf_sk_storage_tracing_allowed(const struct bpf_prog *prog)
{
    const struct btf *btf_vmlinux;
    const struct btf_type *t;
    const char *tname;
    u32 btf_id;

    if (prog->aux->dst_prog)
        return false;

    /* Ensure the tracing program is not tracing
     * any bpf_sk_storage*() function while also
     * using the bpf_sk_storage_(get|delete) helpers.
     */
    switch (prog->expected_attach_type) {
    case BPF_TRACE_ITER:
    case BPF_TRACE_RAW_TP:
        /* bpf_sk_storage has no trace point */
        return true;
    case BPF_TRACE_FENTRY:
    case BPF_TRACE_FEXIT:
        btf_vmlinux = bpf_get_btf_vmlinux();
        if (IS_ERR_OR_NULL(btf_vmlinux))
            return false;
        btf_id = prog->aux->attach_btf_id;
        t = btf_type_by_id(btf_vmlinux, btf_id);
        tname = btf_name_by_offset(btf_vmlinux, t->name_off);
        return !!strncmp(tname, "bpf_sk_storage",
                 strlen("bpf_sk_storage"));
    default:
        return false;
    }

    return false;
}
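
/*
 * Illustrative sketch, not part of this file: given the allowed() check
 * above, fentry/fexit programs may use the sk storage helpers as long as
 * they do not attach to a bpf_sk_storage*() function themselves.  The
 * attach point (inet_listen) and all names below are assumptions for the
 * example.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char _license[] SEC("license") = "GPL";

struct trace_stg {
    __u64 listen_ns;
};

struct {
    __uint(type, BPF_MAP_TYPE_SK_STORAGE);
    __uint(map_flags, BPF_F_NO_PREALLOC);
    __type(key, int);
    __type(value, struct trace_stg);
} trace_stg_map SEC(".maps");

SEC("fentry/inet_listen")
int BPF_PROG(trace_inet_listen, struct socket *sock, int backlog)
{
    struct trace_stg *stg;

    stg = bpf_sk_storage_get(&trace_stg_map, sock->sk, NULL,
                 BPF_SK_STORAGE_GET_F_CREATE);
    if (stg)
        stg->listen_ns = bpf_ktime_get_ns();
    return 0;
}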

/* *gfp_flags* is a hidden argument provided by the verifier */
BPF_CALL_5(bpf_sk_storage_get_tracing, struct bpf_map *, map, struct sock *, sk,
       void *, value, u64, flags, gfp_t, gfp_flags)
{
    WARN_ON_ONCE(!bpf_rcu_lock_held());
    if (in_hardirq() || in_nmi())
        return (unsigned long)NULL;

    return (unsigned long)____bpf_sk_storage_get(map, sk, value, flags,
                             gfp_flags);
}

BPF_CALL_2(bpf_sk_storage_delete_tracing, struct bpf_map *, map,
       struct sock *, sk)
{
    WARN_ON_ONCE(!bpf_rcu_lock_held());
    if (in_hardirq() || in_nmi())
        return -EPERM;

    return ____bpf_sk_storage_delete(map, sk);
}

const struct bpf_func_proto bpf_sk_storage_get_tracing_proto = {
    .func       = bpf_sk_storage_get_tracing,
    .gpl_only   = false,
    .ret_type   = RET_PTR_TO_MAP_VALUE_OR_NULL,
    .arg1_type  = ARG_CONST_MAP_PTR,
    .arg2_type  = ARG_PTR_TO_BTF_ID,
    .arg2_btf_id    = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
    .arg3_type  = ARG_PTR_TO_MAP_VALUE_OR_NULL,
    .arg4_type  = ARG_ANYTHING,
    .allowed    = bpf_sk_storage_tracing_allowed,
};

const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto = {
    .func       = bpf_sk_storage_delete_tracing,
    .gpl_only   = false,
    .ret_type   = RET_INTEGER,
    .arg1_type  = ARG_CONST_MAP_PTR,
    .arg2_type  = ARG_PTR_TO_BTF_ID,
    .arg2_btf_id    = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
    .allowed    = bpf_sk_storage_tracing_allowed,
};

struct bpf_sk_storage_diag {
    u32 nr_maps;
    struct bpf_map *maps[];
};

/* The reply will be like:
 * INET_DIAG_BPF_SK_STORAGES (nla_nest)
 *  SK_DIAG_BPF_STORAGE (nla_nest)
 *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 *  SK_DIAG_BPF_STORAGE (nla_nest)
 *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
 *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
 *  ....
 */
static int nla_value_size(u32 value_size)
{
    /* SK_DIAG_BPF_STORAGE (nla_nest)
     *  SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
     *  SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
     */
    return nla_total_size(0) + nla_total_size(sizeof(u32)) +
        nla_total_size_64bit(value_size);
}

void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
{
    u32 i;

    if (!diag)
        return;

    for (i = 0; i < diag->nr_maps; i++)
        bpf_map_put(diag->maps[i]);

    kfree(diag);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);

static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
               const struct bpf_map *map)
{
    u32 i;

    for (i = 0; i < diag->nr_maps; i++) {
        if (diag->maps[i] == map)
            return true;
    }

    return false;
}

struct bpf_sk_storage_diag *
bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
{
    struct bpf_sk_storage_diag *diag;
    struct nlattr *nla;
    u32 nr_maps = 0;
    int rem, err;

    /* bpf_local_storage_map is currently limited to CAP_SYS_ADMIN,
     * matching the check done on the map_alloc_check() side.
     */
    if (!bpf_capable())
        return ERR_PTR(-EPERM);

    nla_for_each_nested(nla, nla_stgs, rem) {
        if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
            nr_maps++;
    }

    diag = kzalloc(struct_size(diag, maps, nr_maps), GFP_KERNEL);
    if (!diag)
        return ERR_PTR(-ENOMEM);

    nla_for_each_nested(nla, nla_stgs, rem) {
        struct bpf_map *map;
        int map_fd;

        if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
            continue;

        map_fd = nla_get_u32(nla);
        map = bpf_map_get(map_fd);
        if (IS_ERR(map)) {
            err = PTR_ERR(map);
            goto err_free;
        }
        if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
            bpf_map_put(map);
            err = -EINVAL;
            goto err_free;
        }
        if (diag_check_dup(diag, map)) {
            bpf_map_put(map);
            err = -EEXIST;
            goto err_free;
        }
        diag->maps[diag->nr_maps++] = map;
    }

    return diag;

err_free:
    bpf_sk_storage_diag_free(diag);
    return ERR_PTR(err);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
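
/*
 * Illustrative sketch, not part of this file: bpf_sk_storage_diag_alloc()
 * above parses the nest that user space appends to an inet_diag request.
 * A dumper such as ss would add one SK_DIAG_BPF_STORAGE_REQ_MAP_FD
 * attribute per map fd inside an INET_DIAG_REQ_SK_BPF_STORAGES nest; the
 * helper below builds that nest into a caller-provided buffer.  The
 * function name is an assumption; only the attribute constants are real.
 */
#include <string.h>
#include <linux/netlink.h>
#include <linux/inet_diag.h>
#include <linux/sock_diag.h>

/* Append the nest right after struct inet_diag_req_v2 in the request
 * payload; returns the number of bytes written to buf.
 */
static int put_sk_storage_req(char *buf, int map_fd)
{
    struct nlattr *nest = (struct nlattr *)buf;
    struct nlattr *attr = (struct nlattr *)(buf + NLA_HDRLEN);
    __u32 fd = map_fd;

    attr->nla_type = SK_DIAG_BPF_STORAGE_REQ_MAP_FD;
    attr->nla_len = NLA_HDRLEN + sizeof(fd);
    memcpy((char *)attr + NLA_HDRLEN, &fd, sizeof(fd));

    nest->nla_type = INET_DIAG_REQ_SK_BPF_STORAGES | NLA_F_NESTED;
    nest->nla_len = NLA_HDRLEN + NLA_ALIGN(attr->nla_len);

    return NLA_ALIGN(nest->nla_len);
}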

static int diag_get(struct bpf_local_storage_data *sdata, struct sk_buff *skb)
{
    struct nlattr *nla_stg, *nla_value;
    struct bpf_local_storage_map *smap;

    /* It cannot exceed max nlattr's payload */
    BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < BPF_LOCAL_STORAGE_MAX_VALUE_SIZE);

    nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
    if (!nla_stg)
        return -EMSGSIZE;

    smap = rcu_dereference(sdata->smap);
    if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
        goto errout;

    nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
                      smap->map.value_size,
                      SK_DIAG_BPF_STORAGE_PAD);
    if (!nla_value)
        goto errout;

    if (map_value_has_spin_lock(&smap->map))
        copy_map_value_locked(&smap->map, nla_data(nla_value),
                      sdata->data, true);
    else
        copy_map_value(&smap->map, nla_data(nla_value), sdata->data);

    nla_nest_end(skb, nla_stg);
    return 0;

errout:
    nla_nest_cancel(skb, nla_stg);
    return -EMSGSIZE;
}

static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
                       int stg_array_type,
                       unsigned int *res_diag_size)
{
    /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
    unsigned int diag_size = nla_total_size(0);
    struct bpf_local_storage *sk_storage;
    struct bpf_local_storage_elem *selem;
    struct bpf_local_storage_map *smap;
    struct nlattr *nla_stgs;
    unsigned int saved_len;
    int err = 0;

    rcu_read_lock();

    sk_storage = rcu_dereference(sk->sk_bpf_storage);
    if (!sk_storage || hlist_empty(&sk_storage->list)) {
        rcu_read_unlock();
        return 0;
    }

    nla_stgs = nla_nest_start(skb, stg_array_type);
    if (!nla_stgs)
        /* Continue to learn diag_size */
        err = -EMSGSIZE;

    saved_len = skb->len;
    hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
        smap = rcu_dereference(SDATA(selem)->smap);
        diag_size += nla_value_size(smap->map.value_size);

        if (nla_stgs && diag_get(SDATA(selem), skb))
            /* Continue to learn diag_size */
            err = -EMSGSIZE;
    }

    rcu_read_unlock();

    if (nla_stgs) {
        if (saved_len == skb->len)
            nla_nest_cancel(skb, nla_stgs);
        else
            nla_nest_end(skb, nla_stgs);
    }

    if (diag_size == nla_total_size(0)) {
        *res_diag_size = 0;
        return 0;
    }

    *res_diag_size = diag_size;
    return err;
}

int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
                struct sock *sk, struct sk_buff *skb,
                int stg_array_type,
                unsigned int *res_diag_size)
{
    /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
    unsigned int diag_size = nla_total_size(0);
    struct bpf_local_storage *sk_storage;
    struct bpf_local_storage_data *sdata;
    struct nlattr *nla_stgs;
    unsigned int saved_len;
    int err = 0;
    u32 i;

    *res_diag_size = 0;

    /* No map has been specified.  Dump all. */
    if (!diag->nr_maps)
        return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
                           res_diag_size);

    rcu_read_lock();
    sk_storage = rcu_dereference(sk->sk_bpf_storage);
    if (!sk_storage || hlist_empty(&sk_storage->list)) {
        rcu_read_unlock();
        return 0;
    }

    nla_stgs = nla_nest_start(skb, stg_array_type);
    if (!nla_stgs)
        /* Continue to learn diag_size */
        err = -EMSGSIZE;

    saved_len = skb->len;
    for (i = 0; i < diag->nr_maps; i++) {
        sdata = bpf_local_storage_lookup(sk_storage,
                (struct bpf_local_storage_map *)diag->maps[i],
                false);

        if (!sdata)
            continue;

        diag_size += nla_value_size(diag->maps[i]->value_size);

        if (nla_stgs && diag_get(sdata, skb))
            /* Continue to learn diag_size */
            err = -EMSGSIZE;
    }
    rcu_read_unlock();

    if (nla_stgs) {
        if (saved_len == skb->len)
            nla_nest_cancel(skb, nla_stgs);
        else
            nla_nest_end(skb, nla_stgs);
    }

    if (diag_size == nla_total_size(0)) {
        *res_diag_size = 0;
        return 0;
    }

    *res_diag_size = diag_size;
    return err;
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
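
/*
 * Illustrative sketch, not part of this file: the exported diag API above
 * is meant to be driven from a sock_diag/inet_diag dump handler, roughly
 * as below.  The function and variable names are assumptions; only the
 * bpf_sk_storage_diag_*() calls and INET_DIAG_BPF_SK_STORAGES are real.
 */
static int example_put_sk_storages(const struct nlattr *req_stgs,
                   struct sock *sk, struct sk_buff *skb)
{
    struct bpf_sk_storage_diag *diag;
    unsigned int stg_size = 0;
    int err;

    /* Turn the SK_DIAG_BPF_STORAGE_REQ_MAP_FD attrs into map refs. */
    diag = bpf_sk_storage_diag_alloc(req_stgs);
    if (IS_ERR(diag))
        return PTR_ERR(diag);

    /* Emit one SK_DIAG_BPF_STORAGE nest per requested (or every) map;
     * stg_size reports how much room a full dump would have needed.
     */
    err = bpf_sk_storage_diag_put(diag, sk, skb,
                      INET_DIAG_BPF_SK_STORAGES, &stg_size);

    bpf_sk_storage_diag_free(diag);
    return err;
}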

struct bpf_iter_seq_sk_storage_map_info {
    struct bpf_map *map;
    unsigned int bucket_id;
    unsigned skip_elems;
};

static struct bpf_local_storage_elem *
bpf_sk_storage_map_seq_find_next(struct bpf_iter_seq_sk_storage_map_info *info,
                 struct bpf_local_storage_elem *prev_selem)
    __acquires(RCU) __releases(RCU)
{
    struct bpf_local_storage *sk_storage;
    struct bpf_local_storage_elem *selem;
    u32 skip_elems = info->skip_elems;
    struct bpf_local_storage_map *smap;
    u32 bucket_id = info->bucket_id;
    u32 i, count, n_buckets;
    struct bpf_local_storage_map_bucket *b;

    smap = (struct bpf_local_storage_map *)info->map;
    n_buckets = 1U << smap->bucket_log;
    if (bucket_id >= n_buckets)
        return NULL;

    /* try to find next selem in the same bucket */
    selem = prev_selem;
    count = 0;
    while (selem) {
        selem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&selem->map_node)),
                     struct bpf_local_storage_elem, map_node);
        if (!selem) {
            /* not found, unlock and go to the next bucket */
            b = &smap->buckets[bucket_id++];
            rcu_read_unlock();
            skip_elems = 0;
            break;
        }
        sk_storage = rcu_dereference(selem->local_storage);
        if (sk_storage) {
            info->skip_elems = skip_elems + count;
            return selem;
        }
        count++;
    }

    for (i = bucket_id; i < (1U << smap->bucket_log); i++) {
        b = &smap->buckets[i];
        rcu_read_lock();
        count = 0;
        hlist_for_each_entry_rcu(selem, &b->list, map_node) {
            sk_storage = rcu_dereference(selem->local_storage);
            if (sk_storage && count >= skip_elems) {
                info->bucket_id = i;
                info->skip_elems = count;
                return selem;
            }
            count++;
        }
        rcu_read_unlock();
        skip_elems = 0;
    }

    info->bucket_id = i;
    info->skip_elems = 0;
    return NULL;
}

static void *bpf_sk_storage_map_seq_start(struct seq_file *seq, loff_t *pos)
{
    struct bpf_local_storage_elem *selem;

    selem = bpf_sk_storage_map_seq_find_next(seq->private, NULL);
    if (!selem)
        return NULL;

    if (*pos == 0)
        ++*pos;
    return selem;
}

static void *bpf_sk_storage_map_seq_next(struct seq_file *seq, void *v,
                     loff_t *pos)
{
    struct bpf_iter_seq_sk_storage_map_info *info = seq->private;

    ++*pos;
    ++info->skip_elems;
    return bpf_sk_storage_map_seq_find_next(seq->private, v);
}

struct bpf_iter__bpf_sk_storage_map {
    __bpf_md_ptr(struct bpf_iter_meta *, meta);
    __bpf_md_ptr(struct bpf_map *, map);
    __bpf_md_ptr(struct sock *, sk);
    __bpf_md_ptr(void *, value);
};

DEFINE_BPF_ITER_FUNC(bpf_sk_storage_map, struct bpf_iter_meta *meta,
             struct bpf_map *map, struct sock *sk,
             void *value)

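/*
 * Illustrative sketch, not part of this file: a BPF iterator program for
 * this target receives the context defined just above.  sk and value may
 * be NULL (e.g. for the final invocation), and value is whatever type the
 * attached map stores; struct sk_stg and the printed fields are
 * assumptions for the example.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char _license[] SEC("license") = "GPL";

struct sk_stg {
    __u64 pkt_count;
    __u64 created_ns;
};

SEC("iter/bpf_sk_storage_map")
int dump_sk_stg(struct bpf_iter__bpf_sk_storage_map *ctx)
{
    struct seq_file *seq = ctx->meta->seq;
    struct sock *sk = ctx->sk;
    struct sk_stg *stg = ctx->value;

    if (!sk || !stg)
        return 0;

    BPF_SEQ_PRINTF(seq, "port %u pkts %llu\n",
               sk->__sk_common.skc_num, stg->pkt_count);
    return 0;
}
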
static int __bpf_sk_storage_map_seq_show(struct seq_file *seq,
                     struct bpf_local_storage_elem *selem)
{
    struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
    struct bpf_iter__bpf_sk_storage_map ctx = {};
    struct bpf_local_storage *sk_storage;
    struct bpf_iter_meta meta;
    struct bpf_prog *prog;
    int ret = 0;

    meta.seq = seq;
    prog = bpf_iter_get_info(&meta, selem == NULL);
    if (prog) {
        ctx.meta = &meta;
        ctx.map = info->map;
        if (selem) {
            sk_storage = rcu_dereference(selem->local_storage);
            ctx.sk = sk_storage->owner;
            ctx.value = SDATA(selem)->data;
        }
        ret = bpf_iter_run_prog(prog, &ctx);
    }

    return ret;
}

static int bpf_sk_storage_map_seq_show(struct seq_file *seq, void *v)
{
    return __bpf_sk_storage_map_seq_show(seq, v);
}

static void bpf_sk_storage_map_seq_stop(struct seq_file *seq, void *v)
    __releases(RCU)
{
    if (!v)
        (void)__bpf_sk_storage_map_seq_show(seq, v);
    else
        rcu_read_unlock();
}

static int bpf_iter_init_sk_storage_map(void *priv_data,
                    struct bpf_iter_aux_info *aux)
{
    struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data;

    bpf_map_inc_with_uref(aux->map);
    seq_info->map = aux->map;
    return 0;
}

static void bpf_iter_fini_sk_storage_map(void *priv_data)
{
    struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data;

    bpf_map_put_with_uref(seq_info->map);
}

static int bpf_iter_attach_map(struct bpf_prog *prog,
                   union bpf_iter_link_info *linfo,
                   struct bpf_iter_aux_info *aux)
{
    struct bpf_map *map;
    int err = -EINVAL;

    if (!linfo->map.map_fd)
        return -EBADF;

    map = bpf_map_get_with_uref(linfo->map.map_fd);
    if (IS_ERR(map))
        return PTR_ERR(map);

    if (map->map_type != BPF_MAP_TYPE_SK_STORAGE)
        goto put_map;

    if (prog->aux->max_rdwr_access > map->value_size) {
        err = -EACCES;
        goto put_map;
    }

    aux->map = map;
    return 0;

put_map:
    bpf_map_put_with_uref(map);
    return err;
}

static void bpf_iter_detach_map(struct bpf_iter_aux_info *aux)
{
    bpf_map_put_with_uref(aux->map);
}

static const struct seq_operations bpf_sk_storage_map_seq_ops = {
    .start  = bpf_sk_storage_map_seq_start,
    .next   = bpf_sk_storage_map_seq_next,
    .stop   = bpf_sk_storage_map_seq_stop,
    .show   = bpf_sk_storage_map_seq_show,
};

static const struct bpf_iter_seq_info iter_seq_info = {
    .seq_ops        = &bpf_sk_storage_map_seq_ops,
    .init_seq_private   = bpf_iter_init_sk_storage_map,
    .fini_seq_private   = bpf_iter_fini_sk_storage_map,
    .seq_priv_size      = sizeof(struct bpf_iter_seq_sk_storage_map_info),
};

static struct bpf_iter_reg bpf_sk_storage_map_reg_info = {
    .target         = "bpf_sk_storage_map",
    .attach_target      = bpf_iter_attach_map,
    .detach_target      = bpf_iter_detach_map,
    .show_fdinfo        = bpf_iter_map_show_fdinfo,
    .fill_link_info     = bpf_iter_map_fill_link_info,
    .ctx_arg_info_size  = 2,
    .ctx_arg_info       = {
        { offsetof(struct bpf_iter__bpf_sk_storage_map, sk),
          PTR_TO_BTF_ID_OR_NULL },
        { offsetof(struct bpf_iter__bpf_sk_storage_map, value),
          PTR_TO_BUF | PTR_MAYBE_NULL },
    },
    .seq_info       = &iter_seq_info,
};

static int __init bpf_sk_storage_map_iter_init(void)
{
    bpf_sk_storage_map_reg_info.ctx_arg_info[0].btf_id =
        btf_sock_ids[BTF_SOCK_TYPE_SOCK];
    return bpf_iter_reg_target(&bpf_sk_storage_map_reg_info);
}
late_initcall(bpf_sk_storage_map_iter_init);
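
/*
 * Illustrative sketch, not part of this file: user space consumes the
 * "bpf_sk_storage_map" iterator registered above by attaching an iter
 * program to one specific map fd and read()ing the resulting iterator fd.
 * Assumes libbpf >= 1.0 semantics (NULL plus errno on failure); all names
 * are made up for the example.
 */
#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>

int cat_sk_storage_iter(struct bpf_program *prog, int map_fd)
{
    LIBBPF_OPTS(bpf_iter_attach_opts, opts);
    union bpf_iter_link_info linfo = {};
    struct bpf_link *link;
    char buf[4096];
    ssize_t len;
    int iter_fd, err = 0;

    /* bpf_iter_attach_map() above validates this map fd at attach time. */
    linfo.map.map_fd = map_fd;
    opts.link_info = &linfo;
    opts.link_info_len = sizeof(linfo);

    link = bpf_program__attach_iter(prog, &opts);
    if (!link)
        return -errno;

    iter_fd = bpf_iter_create(bpf_link__fd(link));
    if (iter_fd < 0) {
        err = iter_fd;
        goto out;
    }

    /* Each read() drives the seq_file walk implemented above. */
    while ((len = read(iter_fd, buf, sizeof(buf))) > 0)
        fwrite(buf, 1, len, stdout);

    close(iter_fd);
out:
    bpf_link__destroy(link);
    return err;
}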