0001 // SPDX-License-Identifier: GPL-2.0
0002 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
0003 
0004 #include <linux/bpf.h>
0005 #include <linux/btf_ids.h>
0006 #include <linux/filter.h>
0007 #include <linux/errno.h>
0008 #include <linux/file.h>
0009 #include <linux/net.h>
0010 #include <linux/workqueue.h>
0011 #include <linux/skmsg.h>
0012 #include <linux/list.h>
0013 #include <linux/jhash.h>
0014 #include <linux/sock_diag.h>
0015 #include <net/udp.h>
0016 
0017 struct bpf_stab {
0018     struct bpf_map map;
0019     struct sock **sks;
0020     struct sk_psock_progs progs;
0021     raw_spinlock_t lock;
0022 };
0023 
0024 #define SOCK_CREATE_FLAG_MASK               \
0025     (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
0026 
0027 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
0028                 struct bpf_prog *old, u32 which);
0029 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
0030 
0031 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
0032 {
0033     struct bpf_stab *stab;
0034 
0035     if (!capable(CAP_NET_ADMIN))
0036         return ERR_PTR(-EPERM);
0037     if (attr->max_entries == 0 ||
0038         attr->key_size    != 4 ||
0039         (attr->value_size != sizeof(u32) &&
0040          attr->value_size != sizeof(u64)) ||
0041         attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
0042         return ERR_PTR(-EINVAL);
0043 
0044     stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
0045     if (!stab)
0046         return ERR_PTR(-ENOMEM);
0047 
0048     bpf_map_init_from_attr(&stab->map, attr);
0049     raw_spin_lock_init(&stab->lock);
0050 
0051     stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
0052                        sizeof(struct sock *),
0053                        stab->map.numa_node);
0054     if (!stab->sks) {
0055         kfree(stab);
0056         return ERR_PTR(-ENOMEM);
0057     }
0058 
0059     return &stab->map;
0060 }
0061 
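The checks in sock_map_alloc() above pin down the userspace ABI: creating a sockmap requires CAP_NET_ADMIN, a 4-byte key, and a value size of either 4 or 8 bytes. A minimal userspace sketch, assuming libbpf v0.7+ and an illustrative map name:

#include <bpf/bpf.h>        /* bpf_map_create() */
#include <linux/bpf.h>      /* BPF_MAP_TYPE_SOCKMAP */
#include <stdio.h>

static int create_sockmap(void)
{
    /* key_size must be 4; value_size must be sizeof(u32) or sizeof(u64),
     * mirroring the attribute checks in sock_map_alloc(). Creating the
     * map needs CAP_NET_ADMIN. */
    int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "demo_sockmap",
                                sizeof(__u32), sizeof(__u64), 64, NULL);

    if (map_fd < 0)
        perror("bpf_map_create");
    return map_fd;
}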
0062 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
0063 {
0064     u32 ufd = attr->target_fd;
0065     struct bpf_map *map;
0066     struct fd f;
0067     int ret;
0068 
0069     if (attr->attach_flags || attr->replace_bpf_fd)
0070         return -EINVAL;
0071 
0072     f = fdget(ufd);
0073     map = __bpf_map_get(f);
0074     if (IS_ERR(map))
0075         return PTR_ERR(map);
0076     ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
0077     fdput(f);
0078     return ret;
0079 }
0080 
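sock_map_get_from_fd() above is the kernel side of BPF_PROG_ATTACH when the attach target is a sockmap or sockhash fd. A hedged userspace sketch with libbpf (prog_fd and map_fd are assumed to come from earlier bpf() calls):

#include <bpf/bpf.h>        /* bpf_prog_attach() */

static int attach_stream_verdict(int prog_fd, int map_fd)
{
    /* attach_flags must be zero: sock_map_get_from_fd() rejects any
     * non-zero attach_flags or replace_bpf_fd. */
    return bpf_prog_attach(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
}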
0081 int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
0082 {
0083     u32 ufd = attr->target_fd;
0084     struct bpf_prog *prog;
0085     struct bpf_map *map;
0086     struct fd f;
0087     int ret;
0088 
0089     if (attr->attach_flags || attr->replace_bpf_fd)
0090         return -EINVAL;
0091 
0092     f = fdget(ufd);
0093     map = __bpf_map_get(f);
0094     if (IS_ERR(map))
0095         return PTR_ERR(map);
0096 
0097     prog = bpf_prog_get(attr->attach_bpf_fd);
0098     if (IS_ERR(prog)) {
0099         ret = PTR_ERR(prog);
0100         goto put_map;
0101     }
0102 
0103     if (prog->type != ptype) {
0104         ret = -EINVAL;
0105         goto put_prog;
0106     }
0107 
0108     ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
0109 put_prog:
0110     bpf_prog_put(prog);
0111 put_map:
0112     fdput(f);
0113     return ret;
0114 }
0115 
0116 static void sock_map_sk_acquire(struct sock *sk)
0117     __acquires(&sk->sk_lock.slock)
0118 {
0119     lock_sock(sk);
0120     preempt_disable();
0121     rcu_read_lock();
0122 }
0123 
0124 static void sock_map_sk_release(struct sock *sk)
0125     __releases(&sk->sk_lock.slock)
0126 {
0127     rcu_read_unlock();
0128     preempt_enable();
0129     release_sock(sk);
0130 }
0131 
0132 static void sock_map_add_link(struct sk_psock *psock,
0133                   struct sk_psock_link *link,
0134                   struct bpf_map *map, void *link_raw)
0135 {
0136     link->link_raw = link_raw;
0137     link->map = map;
0138     spin_lock_bh(&psock->link_lock);
0139     list_add_tail(&link->list, &psock->link);
0140     spin_unlock_bh(&psock->link_lock);
0141 }
0142 
0143 static void sock_map_del_link(struct sock *sk,
0144                   struct sk_psock *psock, void *link_raw)
0145 {
0146     bool strp_stop = false, verdict_stop = false;
0147     struct sk_psock_link *link, *tmp;
0148 
0149     spin_lock_bh(&psock->link_lock);
0150     list_for_each_entry_safe(link, tmp, &psock->link, list) {
0151         if (link->link_raw == link_raw) {
0152             struct bpf_map *map = link->map;
0153             struct bpf_stab *stab = container_of(map, struct bpf_stab,
0154                                  map);
0155             if (psock->saved_data_ready && stab->progs.stream_parser)
0156                 strp_stop = true;
0157             if (psock->saved_data_ready && stab->progs.stream_verdict)
0158                 verdict_stop = true;
0159             if (psock->saved_data_ready && stab->progs.skb_verdict)
0160                 verdict_stop = true;
0161             list_del(&link->list);
0162             sk_psock_free_link(link);
0163         }
0164     }
0165     spin_unlock_bh(&psock->link_lock);
0166     if (strp_stop || verdict_stop) {
0167         write_lock_bh(&sk->sk_callback_lock);
0168         if (strp_stop)
0169             sk_psock_stop_strp(sk, psock);
0170         if (verdict_stop)
0171             sk_psock_stop_verdict(sk, psock);
0172 
0173         if (psock->psock_update_sk_prot)
0174             psock->psock_update_sk_prot(sk, psock, false);
0175         write_unlock_bh(&sk->sk_callback_lock);
0176     }
0177 }
0178 
0179 static void sock_map_unref(struct sock *sk, void *link_raw)
0180 {
0181     struct sk_psock *psock = sk_psock(sk);
0182 
0183     if (likely(psock)) {
0184         sock_map_del_link(sk, psock, link_raw);
0185         sk_psock_put(sk, psock);
0186     }
0187 }
0188 
0189 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
0190 {
0191     if (!sk->sk_prot->psock_update_sk_prot)
0192         return -EINVAL;
0193     psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
0194     return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
0195 }
0196 
0197 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
0198 {
0199     struct sk_psock *psock;
0200 
0201     rcu_read_lock();
0202     psock = sk_psock(sk);
0203     if (psock) {
0204         if (sk->sk_prot->close != sock_map_close) {
0205             psock = ERR_PTR(-EBUSY);
0206             goto out;
0207         }
0208 
0209         if (!refcount_inc_not_zero(&psock->refcnt))
0210             psock = ERR_PTR(-EBUSY);
0211     }
0212 out:
0213     rcu_read_unlock();
0214     return psock;
0215 }
0216 
0217 static int sock_map_link(struct bpf_map *map, struct sock *sk)
0218 {
0219     struct sk_psock_progs *progs = sock_map_progs(map);
0220     struct bpf_prog *stream_verdict = NULL;
0221     struct bpf_prog *stream_parser = NULL;
0222     struct bpf_prog *skb_verdict = NULL;
0223     struct bpf_prog *msg_parser = NULL;
0224     struct sk_psock *psock;
0225     int ret;
0226 
0227     stream_verdict = READ_ONCE(progs->stream_verdict);
0228     if (stream_verdict) {
0229         stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
0230         if (IS_ERR(stream_verdict))
0231             return PTR_ERR(stream_verdict);
0232     }
0233 
0234     stream_parser = READ_ONCE(progs->stream_parser);
0235     if (stream_parser) {
0236         stream_parser = bpf_prog_inc_not_zero(stream_parser);
0237         if (IS_ERR(stream_parser)) {
0238             ret = PTR_ERR(stream_parser);
0239             goto out_put_stream_verdict;
0240         }
0241     }
0242 
0243     msg_parser = READ_ONCE(progs->msg_parser);
0244     if (msg_parser) {
0245         msg_parser = bpf_prog_inc_not_zero(msg_parser);
0246         if (IS_ERR(msg_parser)) {
0247             ret = PTR_ERR(msg_parser);
0248             goto out_put_stream_parser;
0249         }
0250     }
0251 
0252     skb_verdict = READ_ONCE(progs->skb_verdict);
0253     if (skb_verdict) {
0254         skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
0255         if (IS_ERR(skb_verdict)) {
0256             ret = PTR_ERR(skb_verdict);
0257             goto out_put_msg_parser;
0258         }
0259     }
0260 
0261     psock = sock_map_psock_get_checked(sk);
0262     if (IS_ERR(psock)) {
0263         ret = PTR_ERR(psock);
0264         goto out_progs;
0265     }
0266 
0267     if (psock) {
0268         if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
0269             (stream_parser  && READ_ONCE(psock->progs.stream_parser)) ||
0270             (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
0271             (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
0272             (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
0273             (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
0274             sk_psock_put(sk, psock);
0275             ret = -EBUSY;
0276             goto out_progs;
0277         }
0278     } else {
0279         psock = sk_psock_init(sk, map->numa_node);
0280         if (IS_ERR(psock)) {
0281             ret = PTR_ERR(psock);
0282             goto out_progs;
0283         }
0284     }
0285 
0286     if (msg_parser)
0287         psock_set_prog(&psock->progs.msg_parser, msg_parser);
0288     if (stream_parser)
0289         psock_set_prog(&psock->progs.stream_parser, stream_parser);
0290     if (stream_verdict)
0291         psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
0292     if (skb_verdict)
0293         psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
0294 
0295     /* msg_* and stream_* program references are tracked in psock after this
0296      * point. Reference decrement and cleanup occur through the psock destructor.
0297      */
0298     ret = sock_map_init_proto(sk, psock);
0299     if (ret < 0) {
0300         sk_psock_put(sk, psock);
0301         goto out;
0302     }
0303 
0304     write_lock_bh(&sk->sk_callback_lock);
0305     if (stream_parser && stream_verdict && !psock->saved_data_ready) {
0306         ret = sk_psock_init_strp(sk, psock);
0307         if (ret) {
0308             write_unlock_bh(&sk->sk_callback_lock);
0309             sk_psock_put(sk, psock);
0310             goto out;
0311         }
0312         sk_psock_start_strp(sk, psock);
0313     } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
0314         sk_psock_start_verdict(sk, psock);
0315     } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
0316         sk_psock_start_verdict(sk, psock);
0317     }
0318     write_unlock_bh(&sk->sk_callback_lock);
0319     return 0;
0320 out_progs:
0321     if (skb_verdict)
0322         bpf_prog_put(skb_verdict);
0323 out_put_msg_parser:
0324     if (msg_parser)
0325         bpf_prog_put(msg_parser);
0326 out_put_stream_parser:
0327     if (stream_parser)
0328         bpf_prog_put(stream_parser);
0329 out_put_stream_verdict:
0330     if (stream_verdict)
0331         bpf_prog_put(stream_verdict);
0332 out:
0333     return ret;
0334 }
0335 
0336 static void sock_map_free(struct bpf_map *map)
0337 {
0338     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0339     int i;
0340 
0341     /* After the sync no updates or deletes will be in-flight so it
0342      * is safe to walk the map and remove entries without risking a race
0343      * in the EEXIST update case.
0344      */
0345     synchronize_rcu();
0346     for (i = 0; i < stab->map.max_entries; i++) {
0347         struct sock **psk = &stab->sks[i];
0348         struct sock *sk;
0349 
0350         sk = xchg(psk, NULL);
0351         if (sk) {
0352             lock_sock(sk);
0353             rcu_read_lock();
0354             sock_map_unref(sk, psk);
0355             rcu_read_unlock();
0356             release_sock(sk);
0357         }
0358     }
0359 
0360     /* wait for psock readers accessing its map link */
0361     synchronize_rcu();
0362 
0363     bpf_map_area_free(stab->sks);
0364     kfree(stab);
0365 }
0366 
0367 static void sock_map_release_progs(struct bpf_map *map)
0368 {
0369     psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
0370 }
0371 
0372 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
0373 {
0374     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0375 
0376     WARN_ON_ONCE(!rcu_read_lock_held());
0377 
0378     if (unlikely(key >= map->max_entries))
0379         return NULL;
0380     return READ_ONCE(stab->sks[key]);
0381 }
0382 
0383 static void *sock_map_lookup(struct bpf_map *map, void *key)
0384 {
0385     struct sock *sk;
0386 
0387     sk = __sock_map_lookup_elem(map, *(u32 *)key);
0388     if (!sk)
0389         return NULL;
0390     if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
0391         return NULL;
0392     return sk;
0393 }
0394 
0395 static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
0396 {
0397     struct sock *sk;
0398 
0399     if (map->value_size != sizeof(u64))
0400         return ERR_PTR(-ENOSPC);
0401 
0402     sk = __sock_map_lookup_elem(map, *(u32 *)key);
0403     if (!sk)
0404         return ERR_PTR(-ENOENT);
0405 
0406     __sock_gen_cookie(sk);
0407     return &sk->sk_cookie;
0408 }
0409 
0410 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
0411                  struct sock **psk)
0412 {
0413     struct sock *sk;
0414     int err = 0;
0415 
0416     raw_spin_lock_bh(&stab->lock);
0417     sk = *psk;
0418     if (!sk_test || sk_test == sk)
0419         sk = xchg(psk, NULL);
0420 
0421     if (likely(sk))
0422         sock_map_unref(sk, psk);
0423     else
0424         err = -EINVAL;
0425 
0426     raw_spin_unlock_bh(&stab->lock);
0427     return err;
0428 }
0429 
0430 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
0431                       void *link_raw)
0432 {
0433     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0434 
0435     __sock_map_delete(stab, sk, link_raw);
0436 }
0437 
0438 static int sock_map_delete_elem(struct bpf_map *map, void *key)
0439 {
0440     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0441     u32 i = *(u32 *)key;
0442     struct sock **psk;
0443 
0444     if (unlikely(i >= map->max_entries))
0445         return -EINVAL;
0446 
0447     psk = &stab->sks[i];
0448     return __sock_map_delete(stab, NULL, psk);
0449 }
0450 
0451 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
0452 {
0453     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0454     u32 i = key ? *(u32 *)key : U32_MAX;
0455     u32 *key_next = next;
0456 
0457     if (i == stab->map.max_entries - 1)
0458         return -ENOENT;
0459     if (i >= stab->map.max_entries)
0460         *key_next = 0;
0461     else
0462         *key_next = i + 1;
0463     return 0;
0464 }
0465 
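sock_map_get_next_key() treats a missing or out-of-range key as a request to restart from index 0, which is what makes the usual userspace iteration loop work. A sketch, assuming libbpf; note that for a sockmap this walks indices, whether or not a socket is stored there:

#include <bpf/bpf.h>        /* bpf_map_get_next_key() */
#include <stdio.h>

static void walk_sockmap_indices(int map_fd)
{
    __u32 cur, next;
    void *key = NULL;       /* NULL key: start from the first index */

    while (bpf_map_get_next_key(map_fd, key, &next) == 0) {
        printf("index %u\n", next);
        cur = next;
        key = &cur;
    }
}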
0466 static int sock_map_update_common(struct bpf_map *map, u32 idx,
0467                   struct sock *sk, u64 flags)
0468 {
0469     struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
0470     struct sk_psock_link *link;
0471     struct sk_psock *psock;
0472     struct sock *osk;
0473     int ret;
0474 
0475     WARN_ON_ONCE(!rcu_read_lock_held());
0476     if (unlikely(flags > BPF_EXIST))
0477         return -EINVAL;
0478     if (unlikely(idx >= map->max_entries))
0479         return -E2BIG;
0480 
0481     link = sk_psock_init_link();
0482     if (!link)
0483         return -ENOMEM;
0484 
0485     ret = sock_map_link(map, sk);
0486     if (ret < 0)
0487         goto out_free;
0488 
0489     psock = sk_psock(sk);
0490     WARN_ON_ONCE(!psock);
0491 
0492     raw_spin_lock_bh(&stab->lock);
0493     osk = stab->sks[idx];
0494     if (osk && flags == BPF_NOEXIST) {
0495         ret = -EEXIST;
0496         goto out_unlock;
0497     } else if (!osk && flags == BPF_EXIST) {
0498         ret = -ENOENT;
0499         goto out_unlock;
0500     }
0501 
0502     sock_map_add_link(psock, link, map, &stab->sks[idx]);
0503     stab->sks[idx] = sk;
0504     if (osk)
0505         sock_map_unref(osk, &stab->sks[idx]);
0506     raw_spin_unlock_bh(&stab->lock);
0507     return 0;
0508 out_unlock:
0509     raw_spin_unlock_bh(&stab->lock);
0510     if (psock)
0511         sk_psock_put(sk, psock);
0512 out_free:
0513     sk_psock_free_link(link);
0514     return ret;
0515 }
0516 
0517 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
0518 {
0519     return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
0520            ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
0521            ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
0522 }
0523 
0524 static bool sock_map_redirect_allowed(const struct sock *sk)
0525 {
0526     if (sk_is_tcp(sk))
0527         return sk->sk_state != TCP_LISTEN;
0528     else
0529         return sk->sk_state == TCP_ESTABLISHED;
0530 }
0531 
0532 static bool sock_map_sk_is_suitable(const struct sock *sk)
0533 {
0534     return !!sk->sk_prot->psock_update_sk_prot;
0535 }
0536 
0537 static bool sock_map_sk_state_allowed(const struct sock *sk)
0538 {
0539     if (sk_is_tcp(sk))
0540         return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
0541     return true;
0542 }
0543 
0544 static int sock_hash_update_common(struct bpf_map *map, void *key,
0545                    struct sock *sk, u64 flags);
0546 
0547 int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
0548                  u64 flags)
0549 {
0550     struct socket *sock;
0551     struct sock *sk;
0552     int ret;
0553     u64 ufd;
0554 
0555     if (map->value_size == sizeof(u64))
0556         ufd = *(u64 *)value;
0557     else
0558         ufd = *(u32 *)value;
0559     if (ufd > S32_MAX)
0560         return -EINVAL;
0561 
0562     sock = sockfd_lookup(ufd, &ret);
0563     if (!sock)
0564         return ret;
0565     sk = sock->sk;
0566     if (!sk) {
0567         ret = -EINVAL;
0568         goto out;
0569     }
0570     if (!sock_map_sk_is_suitable(sk)) {
0571         ret = -EOPNOTSUPP;
0572         goto out;
0573     }
0574 
0575     sock_map_sk_acquire(sk);
0576     if (!sock_map_sk_state_allowed(sk))
0577         ret = -EOPNOTSUPP;
0578     else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
0579         ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
0580     else
0581         ret = sock_hash_update_common(map, key, sk, flags);
0582     sock_map_sk_release(sk);
0583 out:
0584     sockfd_put(sock);
0585     return ret;
0586 }
0587 
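sock_map_update_elem_sys() above is reached from the bpf(BPF_MAP_UPDATE_ELEM) syscall with a socket file descriptor as the value; for TCP the socket must be ESTABLISHED or LISTEN (see sock_map_sk_state_allowed()). A hedged sketch, assuming the map was created with an 8-byte value:

#include <bpf/bpf.h>        /* bpf_map_update_elem() */

static int sockmap_add_fd(int map_fd, int sock_fd)
{
    __u32 key = 0;
    __u64 value = sock_fd;  /* use a __u32 here if value_size is 4 */

    return bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
}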
0588 static int sock_map_update_elem(struct bpf_map *map, void *key,
0589                 void *value, u64 flags)
0590 {
0591     struct sock *sk = (struct sock *)value;
0592     int ret;
0593 
0594     if (unlikely(!sk || !sk_fullsock(sk)))
0595         return -EINVAL;
0596 
0597     if (!sock_map_sk_is_suitable(sk))
0598         return -EOPNOTSUPP;
0599 
0600     local_bh_disable();
0601     bh_lock_sock(sk);
0602     if (!sock_map_sk_state_allowed(sk))
0603         ret = -EOPNOTSUPP;
0604     else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
0605         ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
0606     else
0607         ret = sock_hash_update_common(map, key, sk, flags);
0608     bh_unlock_sock(sk);
0609     local_bh_enable();
0610     return ret;
0611 }
0612 
0613 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
0614        struct bpf_map *, map, void *, key, u64, flags)
0615 {
0616     WARN_ON_ONCE(!rcu_read_lock_held());
0617 
0618     if (likely(sock_map_sk_is_suitable(sops->sk) &&
0619            sock_map_op_okay(sops)))
0620         return sock_map_update_common(map, *(u32 *)key, sops->sk,
0621                           flags);
0622     return -EOPNOTSUPP;
0623 }
0624 
0625 const struct bpf_func_proto bpf_sock_map_update_proto = {
0626     .func       = bpf_sock_map_update,
0627     .gpl_only   = false,
0628     .pkt_access = true,
0629     .ret_type   = RET_INTEGER,
0630     .arg1_type  = ARG_PTR_TO_CTX,
0631     .arg2_type  = ARG_CONST_MAP_PTR,
0632     .arg3_type  = ARG_PTR_TO_MAP_KEY,
0633     .arg4_type  = ARG_ANYTHING,
0634 };
0635 
0636 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
0637        struct bpf_map *, map, u32, key, u64, flags)
0638 {
0639     struct sock *sk;
0640 
0641     if (unlikely(flags & ~(BPF_F_INGRESS)))
0642         return SK_DROP;
0643 
0644     sk = __sock_map_lookup_elem(map, key);
0645     if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
0646         return SK_DROP;
0647 
0648     skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
0649     return SK_PASS;
0650 }
0651 
0652 const struct bpf_func_proto bpf_sk_redirect_map_proto = {
0653     .func           = bpf_sk_redirect_map,
0654     .gpl_only       = false,
0655     .ret_type       = RET_INTEGER,
0656     .arg1_type      = ARG_PTR_TO_CTX,
0657     .arg2_type      = ARG_CONST_MAP_PTR,
0658     .arg3_type      = ARG_ANYTHING,
0659     .arg4_type      = ARG_ANYTHING,
0660 };
0661 
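On the BPF program side, bpf_sk_redirect_map() is normally called from an sk_skb verdict program attached to the same map (see the bpf_prog_attach sketch earlier). A minimal sketch, assuming clang/libbpf BPF object conventions and illustrative names:

/* BPF object, built with: clang -O2 -target bpf -c prog.c */
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
    __uint(type, BPF_MAP_TYPE_SOCKMAP);
    __uint(max_entries, 64);
    __type(key, __u32);
    __type(value, __u64);
} sock_map SEC(".maps");

SEC("sk_skb/stream_verdict")
int prog_stream_verdict(struct __sk_buff *skb)
{
    __u32 key = 0;

    /* Redirect traffic to the socket stored at index 0. The last argument
     * may be 0 (egress) or BPF_F_INGRESS; any other flag is rejected by
     * the check in bpf_sk_redirect_map() above. */
    return bpf_sk_redirect_map(skb, &sock_map, key, 0);
}

char _license[] SEC("license") = "GPL";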
0662 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
0663        struct bpf_map *, map, u32, key, u64, flags)
0664 {
0665     struct sock *sk;
0666 
0667     if (unlikely(flags & ~(BPF_F_INGRESS)))
0668         return SK_DROP;
0669 
0670     sk = __sock_map_lookup_elem(map, key);
0671     if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
0672         return SK_DROP;
0673 
0674     msg->flags = flags;
0675     msg->sk_redir = sk;
0676     return SK_PASS;
0677 }
0678 
0679 const struct bpf_func_proto bpf_msg_redirect_map_proto = {
0680     .func           = bpf_msg_redirect_map,
0681     .gpl_only       = false,
0682     .ret_type       = RET_INTEGER,
0683     .arg1_type      = ARG_PTR_TO_CTX,
0684     .arg2_type      = ARG_CONST_MAP_PTR,
0685     .arg3_type      = ARG_ANYTHING,
0686     .arg4_type      = ARG_ANYTHING,
0687 };
0688 
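bpf_msg_redirect_map() is the sendmsg-side counterpart, invoked from an sk_msg program attached with BPF_SK_MSG_VERDICT. A hedged, self-contained sketch with illustrative names:

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
    __uint(type, BPF_MAP_TYPE_SOCKMAP);
    __uint(max_entries, 64);
    __type(key, __u32);
    __type(value, __u64);
} msg_map SEC(".maps");

SEC("sk_msg")
int prog_msg_verdict(struct sk_msg_md *msg)
{
    __u32 key = 0;

    /* Redirect the sendmsg payload to the socket stored at index 0. */
    return bpf_msg_redirect_map(msg, &msg_map, key, 0);
}

char _license[] SEC("license") = "GPL";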
0689 struct sock_map_seq_info {
0690     struct bpf_map *map;
0691     struct sock *sk;
0692     u32 index;
0693 };
0694 
0695 struct bpf_iter__sockmap {
0696     __bpf_md_ptr(struct bpf_iter_meta *, meta);
0697     __bpf_md_ptr(struct bpf_map *, map);
0698     __bpf_md_ptr(void *, key);
0699     __bpf_md_ptr(struct sock *, sk);
0700 };
0701 
0702 DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
0703              struct bpf_map *map, void *key,
0704              struct sock *sk)
0705 
0706 static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
0707 {
0708     if (unlikely(info->index >= info->map->max_entries))
0709         return NULL;
0710 
0711     info->sk = __sock_map_lookup_elem(info->map, info->index);
0712 
0713     /* can't return sk directly, since that might be NULL */
0714     return info;
0715 }
0716 
0717 static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
0718     __acquires(rcu)
0719 {
0720     struct sock_map_seq_info *info = seq->private;
0721 
0722     if (*pos == 0)
0723         ++*pos;
0724 
0725     /* pairs with sock_map_seq_stop */
0726     rcu_read_lock();
0727     return sock_map_seq_lookup_elem(info);
0728 }
0729 
0730 static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
0731     __must_hold(rcu)
0732 {
0733     struct sock_map_seq_info *info = seq->private;
0734 
0735     ++*pos;
0736     ++info->index;
0737 
0738     return sock_map_seq_lookup_elem(info);
0739 }
0740 
0741 static int sock_map_seq_show(struct seq_file *seq, void *v)
0742     __must_hold(rcu)
0743 {
0744     struct sock_map_seq_info *info = seq->private;
0745     struct bpf_iter__sockmap ctx = {};
0746     struct bpf_iter_meta meta;
0747     struct bpf_prog *prog;
0748 
0749     meta.seq = seq;
0750     prog = bpf_iter_get_info(&meta, !v);
0751     if (!prog)
0752         return 0;
0753 
0754     ctx.meta = &meta;
0755     ctx.map = info->map;
0756     if (v) {
0757         ctx.key = &info->index;
0758         ctx.sk = info->sk;
0759     }
0760 
0761     return bpf_iter_run_prog(prog, &ctx);
0762 }
0763 
0764 static void sock_map_seq_stop(struct seq_file *seq, void *v)
0765     __releases(rcu)
0766 {
0767     if (!v)
0768         (void)sock_map_seq_show(seq, NULL);
0769 
0770     /* pairs with sock_map_seq_start */
0771     rcu_read_unlock();
0772 }
0773 
0774 static const struct seq_operations sock_map_seq_ops = {
0775     .start  = sock_map_seq_start,
0776     .next   = sock_map_seq_next,
0777     .stop   = sock_map_seq_stop,
0778     .show   = sock_map_seq_show,
0779 };
0780 
0781 static int sock_map_init_seq_private(void *priv_data,
0782                      struct bpf_iter_aux_info *aux)
0783 {
0784     struct sock_map_seq_info *info = priv_data;
0785 
0786     bpf_map_inc_with_uref(aux->map);
0787     info->map = aux->map;
0788     return 0;
0789 }
0790 
0791 static void sock_map_fini_seq_private(void *priv_data)
0792 {
0793     struct sock_map_seq_info *info = priv_data;
0794 
0795     bpf_map_put_with_uref(info->map);
0796 }
0797 
0798 static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
0799     .seq_ops        = &sock_map_seq_ops,
0800     .init_seq_private   = sock_map_init_seq_private,
0801     .fini_seq_private   = sock_map_fini_seq_private,
0802     .seq_priv_size      = sizeof(struct sock_map_seq_info),
0803 };
0804 
0805 BTF_ID_LIST_SINGLE(sock_map_btf_ids, struct, bpf_stab)
0806 const struct bpf_map_ops sock_map_ops = {
0807     .map_meta_equal     = bpf_map_meta_equal,
0808     .map_alloc      = sock_map_alloc,
0809     .map_free       = sock_map_free,
0810     .map_get_next_key   = sock_map_get_next_key,
0811     .map_lookup_elem_sys_only = sock_map_lookup_sys,
0812     .map_update_elem    = sock_map_update_elem,
0813     .map_delete_elem    = sock_map_delete_elem,
0814     .map_lookup_elem    = sock_map_lookup,
0815     .map_release_uref   = sock_map_release_progs,
0816     .map_check_btf      = map_check_no_btf,
0817     .map_btf_id     = &sock_map_btf_ids[0],
0818     .iter_seq_info      = &sock_map_iter_seq_info,
0819 };
0820 
0821 struct bpf_shtab_elem {
0822     struct rcu_head rcu;
0823     u32 hash;
0824     struct sock *sk;
0825     struct hlist_node node;
0826     u8 key[];
0827 };
0828 
0829 struct bpf_shtab_bucket {
0830     struct hlist_head head;
0831     raw_spinlock_t lock;
0832 };
0833 
0834 struct bpf_shtab {
0835     struct bpf_map map;
0836     struct bpf_shtab_bucket *buckets;
0837     u32 buckets_num;
0838     u32 elem_size;
0839     struct sk_psock_progs progs;
0840     atomic_t count;
0841 };
0842 
0843 static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
0844 {
0845     return jhash(key, len, 0);
0846 }
0847 
0848 static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
0849                             u32 hash)
0850 {
0851     return &htab->buckets[hash & (htab->buckets_num - 1)];
0852 }
0853 
0854 static struct bpf_shtab_elem *
0855 sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
0856               u32 key_size)
0857 {
0858     struct bpf_shtab_elem *elem;
0859 
0860     hlist_for_each_entry_rcu(elem, head, node) {
0861         if (elem->hash == hash &&
0862             !memcmp(&elem->key, key, key_size))
0863             return elem;
0864     }
0865 
0866     return NULL;
0867 }
0868 
0869 static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
0870 {
0871     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
0872     u32 key_size = map->key_size, hash;
0873     struct bpf_shtab_bucket *bucket;
0874     struct bpf_shtab_elem *elem;
0875 
0876     WARN_ON_ONCE(!rcu_read_lock_held());
0877 
0878     hash = sock_hash_bucket_hash(key, key_size);
0879     bucket = sock_hash_select_bucket(htab, hash);
0880     elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
0881 
0882     return elem ? elem->sk : NULL;
0883 }
0884 
0885 static void sock_hash_free_elem(struct bpf_shtab *htab,
0886                 struct bpf_shtab_elem *elem)
0887 {
0888     atomic_dec(&htab->count);
0889     kfree_rcu(elem, rcu);
0890 }
0891 
0892 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
0893                        void *link_raw)
0894 {
0895     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
0896     struct bpf_shtab_elem *elem_probe, *elem = link_raw;
0897     struct bpf_shtab_bucket *bucket;
0898 
0899     WARN_ON_ONCE(!rcu_read_lock_held());
0900     bucket = sock_hash_select_bucket(htab, elem->hash);
0901 
0902     /* elem may be deleted in parallel from the map, but access here
0903      * is okay since it's going away only after RCU grace period.
0904      * However, we need to check whether it's still present.
0905      */
0906     raw_spin_lock_bh(&bucket->lock);
0907     elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
0908                            elem->key, map->key_size);
0909     if (elem_probe && elem_probe == elem) {
0910         hlist_del_rcu(&elem->node);
0911         sock_map_unref(elem->sk, elem);
0912         sock_hash_free_elem(htab, elem);
0913     }
0914     raw_spin_unlock_bh(&bucket->lock);
0915 }
0916 
0917 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
0918 {
0919     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
0920     u32 hash, key_size = map->key_size;
0921     struct bpf_shtab_bucket *bucket;
0922     struct bpf_shtab_elem *elem;
0923     int ret = -ENOENT;
0924 
0925     hash = sock_hash_bucket_hash(key, key_size);
0926     bucket = sock_hash_select_bucket(htab, hash);
0927 
0928     raw_spin_lock_bh(&bucket->lock);
0929     elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
0930     if (elem) {
0931         hlist_del_rcu(&elem->node);
0932         sock_map_unref(elem->sk, elem);
0933         sock_hash_free_elem(htab, elem);
0934         ret = 0;
0935     }
0936     raw_spin_unlock_bh(&bucket->lock);
0937     return ret;
0938 }
0939 
0940 static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
0941                            void *key, u32 key_size,
0942                            u32 hash, struct sock *sk,
0943                            struct bpf_shtab_elem *old)
0944 {
0945     struct bpf_shtab_elem *new;
0946 
0947     if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
0948         if (!old) {
0949             atomic_dec(&htab->count);
0950             return ERR_PTR(-E2BIG);
0951         }
0952     }
0953 
0954     new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
0955                    GFP_ATOMIC | __GFP_NOWARN,
0956                    htab->map.numa_node);
0957     if (!new) {
0958         atomic_dec(&htab->count);
0959         return ERR_PTR(-ENOMEM);
0960     }
0961     memcpy(new->key, key, key_size);
0962     new->sk = sk;
0963     new->hash = hash;
0964     return new;
0965 }
0966 
0967 static int sock_hash_update_common(struct bpf_map *map, void *key,
0968                    struct sock *sk, u64 flags)
0969 {
0970     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
0971     u32 key_size = map->key_size, hash;
0972     struct bpf_shtab_elem *elem, *elem_new;
0973     struct bpf_shtab_bucket *bucket;
0974     struct sk_psock_link *link;
0975     struct sk_psock *psock;
0976     int ret;
0977 
0978     WARN_ON_ONCE(!rcu_read_lock_held());
0979     if (unlikely(flags > BPF_EXIST))
0980         return -EINVAL;
0981 
0982     link = sk_psock_init_link();
0983     if (!link)
0984         return -ENOMEM;
0985 
0986     ret = sock_map_link(map, sk);
0987     if (ret < 0)
0988         goto out_free;
0989 
0990     psock = sk_psock(sk);
0991     WARN_ON_ONCE(!psock);
0992 
0993     hash = sock_hash_bucket_hash(key, key_size);
0994     bucket = sock_hash_select_bucket(htab, hash);
0995 
0996     raw_spin_lock_bh(&bucket->lock);
0997     elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
0998     if (elem && flags == BPF_NOEXIST) {
0999         ret = -EEXIST;
1000         goto out_unlock;
1001     } else if (!elem && flags == BPF_EXIST) {
1002         ret = -ENOENT;
1003         goto out_unlock;
1004     }
1005 
1006     elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
1007     if (IS_ERR(elem_new)) {
1008         ret = PTR_ERR(elem_new);
1009         goto out_unlock;
1010     }
1011 
1012     sock_map_add_link(psock, link, map, elem_new);
1013     /* Add new element to the head of the list, so that
1014      * concurrent search will find it before old elem.
1015      */
1016     hlist_add_head_rcu(&elem_new->node, &bucket->head);
1017     if (elem) {
1018         hlist_del_rcu(&elem->node);
1019         sock_map_unref(elem->sk, elem);
1020         sock_hash_free_elem(htab, elem);
1021     }
1022     raw_spin_unlock_bh(&bucket->lock);
1023     return 0;
1024 out_unlock:
1025     raw_spin_unlock_bh(&bucket->lock);
1026     sk_psock_put(sk, psock);
1027 out_free:
1028     sk_psock_free_link(link);
1029     return ret;
1030 }
1031 
1032 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
1033                   void *key_next)
1034 {
1035     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1036     struct bpf_shtab_elem *elem, *elem_next;
1037     u32 hash, key_size = map->key_size;
1038     struct hlist_head *head;
1039     int i = 0;
1040 
1041     if (!key)
1042         goto find_first_elem;
1043     hash = sock_hash_bucket_hash(key, key_size);
1044     head = &sock_hash_select_bucket(htab, hash)->head;
1045     elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
1046     if (!elem)
1047         goto find_first_elem;
1048 
1049     elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
1050                      struct bpf_shtab_elem, node);
1051     if (elem_next) {
1052         memcpy(key_next, elem_next->key, key_size);
1053         return 0;
1054     }
1055 
1056     i = hash & (htab->buckets_num - 1);
1057     i++;
1058 find_first_elem:
1059     for (; i < htab->buckets_num; i++) {
1060         head = &sock_hash_select_bucket(htab, i)->head;
1061         elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
1062                          struct bpf_shtab_elem, node);
1063         if (elem_next) {
1064             memcpy(key_next, elem_next->key, key_size);
1065             return 0;
1066         }
1067     }
1068 
1069     return -ENOENT;
1070 }
1071 
1072 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
1073 {
1074     struct bpf_shtab *htab;
1075     int i, err;
1076 
1077     if (!capable(CAP_NET_ADMIN))
1078         return ERR_PTR(-EPERM);
1079     if (attr->max_entries == 0 ||
1080         attr->key_size    == 0 ||
1081         (attr->value_size != sizeof(u32) &&
1082          attr->value_size != sizeof(u64)) ||
1083         attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1084         return ERR_PTR(-EINVAL);
1085     if (attr->key_size > MAX_BPF_STACK)
1086         return ERR_PTR(-E2BIG);
1087 
1088     htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
1089     if (!htab)
1090         return ERR_PTR(-ENOMEM);
1091 
1092     bpf_map_init_from_attr(&htab->map, attr);
1093 
1094     htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
1095     htab->elem_size = sizeof(struct bpf_shtab_elem) +
1096               round_up(htab->map.key_size, 8);
1097     if (htab->buckets_num == 0 ||
1098         htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
1099         err = -EINVAL;
1100         goto free_htab;
1101     }
1102 
1103     htab->buckets = bpf_map_area_alloc(htab->buckets_num *
1104                        sizeof(struct bpf_shtab_bucket),
1105                        htab->map.numa_node);
1106     if (!htab->buckets) {
1107         err = -ENOMEM;
1108         goto free_htab;
1109     }
1110 
1111     for (i = 0; i < htab->buckets_num; i++) {
1112         INIT_HLIST_HEAD(&htab->buckets[i].head);
1113         raw_spin_lock_init(&htab->buckets[i].lock);
1114     }
1115 
1116     return &htab->map;
1117 free_htab:
1118     kfree(htab);
1119     return ERR_PTR(err);
1120 }
1121 
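sock_hash_alloc() above relaxes the key restriction: any fixed-size key up to MAX_BPF_STACK bytes is accepted, so a connection tuple can index the hash. A hedged userspace sketch; the tuple layout is illustrative, not a kernel-defined structure:

#include <bpf/bpf.h>
#include <linux/bpf.h>

struct tuple_key {          /* illustrative key layout */
    __u32 saddr;
    __u32 daddr;
    __u16 sport;
    __u16 dport;
};

static int create_sockhash(void)
{
    return bpf_map_create(BPF_MAP_TYPE_SOCKHASH, "demo_sockhash",
                          sizeof(struct tuple_key), sizeof(__u64),
                          1024, NULL);
}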
1122 static void sock_hash_free(struct bpf_map *map)
1123 {
1124     struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1125     struct bpf_shtab_bucket *bucket;
1126     struct hlist_head unlink_list;
1127     struct bpf_shtab_elem *elem;
1128     struct hlist_node *node;
1129     int i;
1130 
1131     /* After the sync no updates or deletes will be in-flight so it
1132      * is safe to walk the map and remove entries without risking a race
1133      * in the EEXIST update case.
1134      */
1135     synchronize_rcu();
1136     for (i = 0; i < htab->buckets_num; i++) {
1137         bucket = sock_hash_select_bucket(htab, i);
1138 
1139         /* We are racing with sock_hash_delete_from_link to
1140          * enter the spin-lock critical section. Every socket on
1141          * the list is still linked to sockhash. Since the link
1142          * exists, the psock exists and holds a ref to the socket.
1143          * That lets us grab a socket ref too.
1144          */
1145         raw_spin_lock_bh(&bucket->lock);
1146         hlist_for_each_entry(elem, &bucket->head, node)
1147             sock_hold(elem->sk);
1148         hlist_move_list(&bucket->head, &unlink_list);
1149         raw_spin_unlock_bh(&bucket->lock);
1150 
1151         /* Process removed entries out of atomic context to
1152          * block for socket lock before deleting the psock's
1153          * link to sockhash.
1154          */
1155         hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
1156             hlist_del(&elem->node);
1157             lock_sock(elem->sk);
1158             rcu_read_lock();
1159             sock_map_unref(elem->sk, elem);
1160             rcu_read_unlock();
1161             release_sock(elem->sk);
1162             sock_put(elem->sk);
1163             sock_hash_free_elem(htab, elem);
1164         }
1165     }
1166 
1167     /* wait for psock readers accessing its map link */
1168     synchronize_rcu();
1169 
1170     bpf_map_area_free(htab->buckets);
1171     kfree(htab);
1172 }
1173 
1174 static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
1175 {
1176     struct sock *sk;
1177 
1178     if (map->value_size != sizeof(u64))
1179         return ERR_PTR(-ENOSPC);
1180 
1181     sk = __sock_hash_lookup_elem(map, key);
1182     if (!sk)
1183         return ERR_PTR(-ENOENT);
1184 
1185     __sock_gen_cookie(sk);
1186     return &sk->sk_cookie;
1187 }
1188 
1189 static void *sock_hash_lookup(struct bpf_map *map, void *key)
1190 {
1191     struct sock *sk;
1192 
1193     sk = __sock_hash_lookup_elem(map, key);
1194     if (!sk)
1195         return NULL;
1196     if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
1197         return NULL;
1198     return sk;
1199 }
1200 
1201 static void sock_hash_release_progs(struct bpf_map *map)
1202 {
1203     psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
1204 }
1205 
1206 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
1207        struct bpf_map *, map, void *, key, u64, flags)
1208 {
1209     WARN_ON_ONCE(!rcu_read_lock_held());
1210 
1211     if (likely(sock_map_sk_is_suitable(sops->sk) &&
1212            sock_map_op_okay(sops)))
1213         return sock_hash_update_common(map, key, sops->sk, flags);
1214     return -EOPNOTSUPP;
1215 }
1216 
1217 const struct bpf_func_proto bpf_sock_hash_update_proto = {
1218     .func       = bpf_sock_hash_update,
1219     .gpl_only   = false,
1220     .pkt_access = true,
1221     .ret_type   = RET_INTEGER,
1222     .arg1_type  = ARG_PTR_TO_CTX,
1223     .arg2_type  = ARG_CONST_MAP_PTR,
1224     .arg3_type  = ARG_PTR_TO_MAP_KEY,
1225     .arg4_type  = ARG_ANYTHING,
1226 };
1227 
1228 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
1229        struct bpf_map *, map, void *, key, u64, flags)
1230 {
1231     struct sock *sk;
1232 
1233     if (unlikely(flags & ~(BPF_F_INGRESS)))
1234         return SK_DROP;
1235 
1236     sk = __sock_hash_lookup_elem(map, key);
1237     if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1238         return SK_DROP;
1239 
1240     skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
1241     return SK_PASS;
1242 }
1243 
1244 const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
1245     .func           = bpf_sk_redirect_hash,
1246     .gpl_only       = false,
1247     .ret_type       = RET_INTEGER,
1248     .arg1_type      = ARG_PTR_TO_CTX,
1249     .arg2_type      = ARG_CONST_MAP_PTR,
1250     .arg3_type      = ARG_PTR_TO_MAP_KEY,
1251     .arg4_type      = ARG_ANYTHING,
1252 };
1253 
1254 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
1255        struct bpf_map *, map, void *, key, u64, flags)
1256 {
1257     struct sock *sk;
1258 
1259     if (unlikely(flags & ~(BPF_F_INGRESS)))
1260         return SK_DROP;
1261 
1262     sk = __sock_hash_lookup_elem(map, key);
1263     if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1264         return SK_DROP;
1265 
1266     msg->flags = flags;
1267     msg->sk_redir = sk;
1268     return SK_PASS;
1269 }
1270 
1271 const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
1272     .func           = bpf_msg_redirect_hash,
1273     .gpl_only       = false,
1274     .ret_type       = RET_INTEGER,
1275     .arg1_type      = ARG_PTR_TO_CTX,
1276     .arg2_type      = ARG_CONST_MAP_PTR,
1277     .arg3_type      = ARG_PTR_TO_MAP_KEY,
1278     .arg4_type      = ARG_ANYTHING,
1279 };
1280 
1281 struct sock_hash_seq_info {
1282     struct bpf_map *map;
1283     struct bpf_shtab *htab;
1284     u32 bucket_id;
1285 };
1286 
1287 static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
1288                      struct bpf_shtab_elem *prev_elem)
1289 {
1290     const struct bpf_shtab *htab = info->htab;
1291     struct bpf_shtab_bucket *bucket;
1292     struct bpf_shtab_elem *elem;
1293     struct hlist_node *node;
1294 
1295     /* try to find next elem in the same bucket */
1296     if (prev_elem) {
1297         node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
1298         elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1299         if (elem)
1300             return elem;
1301 
1302         /* no more elements, continue in the next bucket */
1303         info->bucket_id++;
1304     }
1305 
1306     for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
1307         bucket = &htab->buckets[info->bucket_id];
1308         node = rcu_dereference(hlist_first_rcu(&bucket->head));
1309         elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1310         if (elem)
1311             return elem;
1312     }
1313 
1314     return NULL;
1315 }
1316 
1317 static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
1318     __acquires(rcu)
1319 {
1320     struct sock_hash_seq_info *info = seq->private;
1321 
1322     if (*pos == 0)
1323         ++*pos;
1324 
1325     /* pairs with sock_hash_seq_stop */
1326     rcu_read_lock();
1327     return sock_hash_seq_find_next(info, NULL);
1328 }
1329 
1330 static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
1331     __must_hold(rcu)
1332 {
1333     struct sock_hash_seq_info *info = seq->private;
1334 
1335     ++*pos;
1336     return sock_hash_seq_find_next(info, v);
1337 }
1338 
1339 static int sock_hash_seq_show(struct seq_file *seq, void *v)
1340     __must_hold(rcu)
1341 {
1342     struct sock_hash_seq_info *info = seq->private;
1343     struct bpf_iter__sockmap ctx = {};
1344     struct bpf_shtab_elem *elem = v;
1345     struct bpf_iter_meta meta;
1346     struct bpf_prog *prog;
1347 
1348     meta.seq = seq;
1349     prog = bpf_iter_get_info(&meta, !elem);
1350     if (!prog)
1351         return 0;
1352 
1353     ctx.meta = &meta;
1354     ctx.map = info->map;
1355     if (elem) {
1356         ctx.key = elem->key;
1357         ctx.sk = elem->sk;
1358     }
1359 
1360     return bpf_iter_run_prog(prog, &ctx);
1361 }
1362 
1363 static void sock_hash_seq_stop(struct seq_file *seq, void *v)
1364     __releases(rcu)
1365 {
1366     if (!v)
1367         (void)sock_hash_seq_show(seq, NULL);
1368 
1369     /* pairs with sock_hash_seq_start */
1370     rcu_read_unlock();
1371 }
1372 
1373 static const struct seq_operations sock_hash_seq_ops = {
1374     .start  = sock_hash_seq_start,
1375     .next   = sock_hash_seq_next,
1376     .stop   = sock_hash_seq_stop,
1377     .show   = sock_hash_seq_show,
1378 };
1379 
1380 static int sock_hash_init_seq_private(void *priv_data,
1381                       struct bpf_iter_aux_info *aux)
1382 {
1383     struct sock_hash_seq_info *info = priv_data;
1384 
1385     bpf_map_inc_with_uref(aux->map);
1386     info->map = aux->map;
1387     info->htab = container_of(aux->map, struct bpf_shtab, map);
1388     return 0;
1389 }
1390 
1391 static void sock_hash_fini_seq_private(void *priv_data)
1392 {
1393     struct sock_hash_seq_info *info = priv_data;
1394 
1395     bpf_map_put_with_uref(info->map);
1396 }
1397 
1398 static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
1399     .seq_ops        = &sock_hash_seq_ops,
1400     .init_seq_private   = sock_hash_init_seq_private,
1401     .fini_seq_private   = sock_hash_fini_seq_private,
1402     .seq_priv_size      = sizeof(struct sock_hash_seq_info),
1403 };
1404 
1405 BTF_ID_LIST_SINGLE(sock_hash_map_btf_ids, struct, bpf_shtab)
1406 const struct bpf_map_ops sock_hash_ops = {
1407     .map_meta_equal     = bpf_map_meta_equal,
1408     .map_alloc      = sock_hash_alloc,
1409     .map_free       = sock_hash_free,
1410     .map_get_next_key   = sock_hash_get_next_key,
1411     .map_update_elem    = sock_map_update_elem,
1412     .map_delete_elem    = sock_hash_delete_elem,
1413     .map_lookup_elem    = sock_hash_lookup,
1414     .map_lookup_elem_sys_only = sock_hash_lookup_sys,
1415     .map_release_uref   = sock_hash_release_progs,
1416     .map_check_btf      = map_check_no_btf,
1417     .map_btf_id     = &sock_hash_map_btf_ids[0],
1418     .iter_seq_info      = &sock_hash_iter_seq_info,
1419 };
1420 
1421 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
1422 {
1423     switch (map->map_type) {
1424     case BPF_MAP_TYPE_SOCKMAP:
1425         return &container_of(map, struct bpf_stab, map)->progs;
1426     case BPF_MAP_TYPE_SOCKHASH:
1427         return &container_of(map, struct bpf_shtab, map)->progs;
1428     default:
1429         break;
1430     }
1431 
1432     return NULL;
1433 }
1434 
1435 static int sock_map_prog_lookup(struct bpf_map *map, struct bpf_prog ***pprog,
1436                 u32 which)
1437 {
1438     struct sk_psock_progs *progs = sock_map_progs(map);
1439 
1440     if (!progs)
1441         return -EOPNOTSUPP;
1442 
1443     switch (which) {
1444     case BPF_SK_MSG_VERDICT:
1445         *pprog = &progs->msg_parser;
1446         break;
1447 #if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
1448     case BPF_SK_SKB_STREAM_PARSER:
1449         *pprog = &progs->stream_parser;
1450         break;
1451 #endif
1452     case BPF_SK_SKB_STREAM_VERDICT:
1453         if (progs->skb_verdict)
1454             return -EBUSY;
1455         *pprog = &progs->stream_verdict;
1456         break;
1457     case BPF_SK_SKB_VERDICT:
1458         if (progs->stream_verdict)
1459             return -EBUSY;
1460         *pprog = &progs->skb_verdict;
1461         break;
1462     default:
1463         return -EOPNOTSUPP;
1464     }
1465 
1466     return 0;
1467 }
1468 
1469 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
1470                 struct bpf_prog *old, u32 which)
1471 {
1472     struct bpf_prog **pprog;
1473     int ret;
1474 
1475     ret = sock_map_prog_lookup(map, &pprog, which);
1476     if (ret)
1477         return ret;
1478 
1479     if (old)
1480         return psock_replace_prog(pprog, prog, old);
1481 
1482     psock_set_prog(pprog, prog);
1483     return 0;
1484 }
1485 
1486 int sock_map_bpf_prog_query(const union bpf_attr *attr,
1487                 union bpf_attr __user *uattr)
1488 {
1489     __u32 __user *prog_ids = u64_to_user_ptr(attr->query.prog_ids);
1490     u32 prog_cnt = 0, flags = 0, ufd = attr->target_fd;
1491     struct bpf_prog **pprog;
1492     struct bpf_prog *prog;
1493     struct bpf_map *map;
1494     struct fd f;
1495     u32 id = 0;
1496     int ret;
1497 
1498     if (attr->query.query_flags)
1499         return -EINVAL;
1500 
1501     f = fdget(ufd);
1502     map = __bpf_map_get(f);
1503     if (IS_ERR(map))
1504         return PTR_ERR(map);
1505 
1506     rcu_read_lock();
1507 
1508     ret = sock_map_prog_lookup(map, &pprog, attr->query.attach_type);
1509     if (ret)
1510         goto end;
1511 
1512     prog = *pprog;
1513     prog_cnt = !prog ? 0 : 1;
1514 
1515     if (!attr->query.prog_cnt || !prog_ids || !prog_cnt)
1516         goto end;
1517 
1518     /* we do not hold the refcnt; the bpf prog may be released
1519      * asynchronously and the id would be set to 0.
1520      */
1521     id = data_race(prog->aux->id);
1522     if (id == 0)
1523         prog_cnt = 0;
1524 
1525 end:
1526     rcu_read_unlock();
1527 
1528     if (copy_to_user(&uattr->query.attach_flags, &flags, sizeof(flags)) ||
1529         (id != 0 && copy_to_user(prog_ids, &id, sizeof(u32))) ||
1530         copy_to_user(&uattr->query.prog_cnt, &prog_cnt, sizeof(prog_cnt)))
1531         ret = -EFAULT;
1532 
1533     fdput(f);
1534     return ret;
1535 }
1536 
1537 static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
1538 {
1539     switch (link->map->map_type) {
1540     case BPF_MAP_TYPE_SOCKMAP:
1541         return sock_map_delete_from_link(link->map, sk,
1542                          link->link_raw);
1543     case BPF_MAP_TYPE_SOCKHASH:
1544         return sock_hash_delete_from_link(link->map, sk,
1545                           link->link_raw);
1546     default:
1547         break;
1548     }
1549 }
1550 
1551 static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
1552 {
1553     struct sk_psock_link *link;
1554 
1555     while ((link = sk_psock_link_pop(psock))) {
1556         sock_map_unlink(sk, link);
1557         sk_psock_free_link(link);
1558     }
1559 }
1560 
1561 void sock_map_unhash(struct sock *sk)
1562 {
1563     void (*saved_unhash)(struct sock *sk);
1564     struct sk_psock *psock;
1565 
1566     rcu_read_lock();
1567     psock = sk_psock(sk);
1568     if (unlikely(!psock)) {
1569         rcu_read_unlock();
1570         if (sk->sk_prot->unhash)
1571             sk->sk_prot->unhash(sk);
1572         return;
1573     }
1574 
1575     saved_unhash = psock->saved_unhash;
1576     sock_map_remove_links(sk, psock);
1577     rcu_read_unlock();
1578     saved_unhash(sk);
1579 }
1580 EXPORT_SYMBOL_GPL(sock_map_unhash);
1581 
1582 void sock_map_destroy(struct sock *sk)
1583 {
1584     void (*saved_destroy)(struct sock *sk);
1585     struct sk_psock *psock;
1586 
1587     rcu_read_lock();
1588     psock = sk_psock_get(sk);
1589     if (unlikely(!psock)) {
1590         rcu_read_unlock();
1591         if (sk->sk_prot->destroy)
1592             sk->sk_prot->destroy(sk);
1593         return;
1594     }
1595 
1596     saved_destroy = psock->saved_destroy;
1597     sock_map_remove_links(sk, psock);
1598     rcu_read_unlock();
1599     sk_psock_stop(psock, false);
1600     sk_psock_put(sk, psock);
1601     saved_destroy(sk);
1602 }
1603 EXPORT_SYMBOL_GPL(sock_map_destroy);
1604 
1605 void sock_map_close(struct sock *sk, long timeout)
1606 {
1607     void (*saved_close)(struct sock *sk, long timeout);
1608     struct sk_psock *psock;
1609 
1610     lock_sock(sk);
1611     rcu_read_lock();
1612     psock = sk_psock_get(sk);
1613     if (unlikely(!psock)) {
1614         rcu_read_unlock();
1615         release_sock(sk);
1616         return sk->sk_prot->close(sk, timeout);
1617     }
1618 
1619     saved_close = psock->saved_close;
1620     sock_map_remove_links(sk, psock);
1621     rcu_read_unlock();
1622     sk_psock_stop(psock, true);
1623     sk_psock_put(sk, psock);
1624     release_sock(sk);
1625     saved_close(sk, timeout);
1626 }
1627 EXPORT_SYMBOL_GPL(sock_map_close);
1628 
1629 static int sock_map_iter_attach_target(struct bpf_prog *prog,
1630                        union bpf_iter_link_info *linfo,
1631                        struct bpf_iter_aux_info *aux)
1632 {
1633     struct bpf_map *map;
1634     int err = -EINVAL;
1635 
1636     if (!linfo->map.map_fd)
1637         return -EBADF;
1638 
1639     map = bpf_map_get_with_uref(linfo->map.map_fd);
1640     if (IS_ERR(map))
1641         return PTR_ERR(map);
1642 
1643     if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
1644         map->map_type != BPF_MAP_TYPE_SOCKHASH)
1645         goto put_map;
1646 
1647     if (prog->aux->max_rdonly_access > map->key_size) {
1648         err = -EACCES;
1649         goto put_map;
1650     }
1651 
1652     aux->map = map;
1653     return 0;
1654 
1655 put_map:
1656     bpf_map_put_with_uref(map);
1657     return err;
1658 }
1659 
1660 static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
1661 {
1662     bpf_map_put_with_uref(aux->map);
1663 }
1664 
1665 static struct bpf_iter_reg sock_map_iter_reg = {
1666     .target         = "sockmap",
1667     .attach_target      = sock_map_iter_attach_target,
1668     .detach_target      = sock_map_iter_detach_target,
1669     .show_fdinfo        = bpf_iter_map_show_fdinfo,
1670     .fill_link_info     = bpf_iter_map_fill_link_info,
1671     .ctx_arg_info_size  = 2,
1672     .ctx_arg_info       = {
1673         { offsetof(struct bpf_iter__sockmap, key),
1674           PTR_TO_BUF | PTR_MAYBE_NULL | MEM_RDONLY },
1675         { offsetof(struct bpf_iter__sockmap, sk),
1676           PTR_TO_BTF_ID_OR_NULL },
1677     },
1678 };
1679 
1680 static int __init bpf_sockmap_iter_init(void)
1681 {
1682     sock_map_iter_reg.ctx_arg_info[1].btf_id =
1683         btf_sock_ids[BTF_SOCK_TYPE_SOCK];
1684     return bpf_iter_reg_target(&sock_map_iter_reg);
1685 }
1686 late_initcall(bpf_sockmap_iter_init);
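The "sockmap" iterator target registered above can be consumed from a BPF iterator program. A hedged sketch of the BPF side, modeled on the selftest conventions; vmlinux.h is assumed to be generated with bpftool, and the program name is illustrative:

#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>    /* BPF_SEQ_PRINTF() */

SEC("iter/sockmap")
int dump_sockmap(struct bpf_iter__sockmap *ctx)
{
    struct seq_file *seq = ctx->meta->seq;
    struct sock *sk = ctx->sk;
    __u32 *key = ctx->key;

    /* key is NULL on the final call of the iteration; sk is NULL for
     * empty sockmap slots. */
    if (!key || !sk)
        return 0;

    BPF_SEQ_PRINTF(seq, "key %u family %d\n", *key,
                   sk->__sk_common.skc_family);
    return 0;
}

char _license[] SEC("license") = "GPL";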