Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /*
0003  * Copyright (c) 2018 Facebook
0004  */
0005 #include <linux/bpf.h>
0006 #include <linux/err.h>
0007 #include <linux/sock_diag.h>
0008 #include <net/sock_reuseport.h>
0009 #include <linux/btf_ids.h>
0010 
/*
 * A reuseport sockarray map: a fixed-size array of RCU-protected
 * socket pointers keyed by u32 index.
 */
struct reuseport_array {
	struct bpf_map map;	/* must stay first: reuseport_array() casts bpf_map* to this */
	struct sock __rcu *ptrs[];	/* one socket slot per map index; flexible array member */
};
0015 
0016 static struct reuseport_array *reuseport_array(struct bpf_map *map)
0017 {
0018     return (struct reuseport_array *)map;
0019 }
0020 
/* The caller must hold the reuseport_lock */
void bpf_sk_reuseport_detach(struct sock *sk)
{
	struct sock __rcu **socks;

	write_lock_bh(&sk->sk_callback_lock);
	/* sk->sk_user_data points back into array->ptrs[] and is tagged
	 * with SK_USER_DATA_BPF; the helper returns NULL when the tag
	 * is absent, i.e. the socket is not in a reuseport array.
	 */
	socks = __locked_read_sk_user_data_with_flags(sk, SK_USER_DATA_BPF);
	if (socks) {
		WRITE_ONCE(sk->sk_user_data, NULL);
		/*
		 * Do not move this NULL assignment outside of
		 * sk->sk_callback_lock because there is
		 * a race with reuseport_array_free()
		 * which does not hold the reuseport_lock.
		 */
		RCU_INIT_POINTER(*socks, NULL);
	}
	write_unlock_bh(&sk->sk_callback_lock);
}
0040 
0041 static int reuseport_array_alloc_check(union bpf_attr *attr)
0042 {
0043     if (attr->value_size != sizeof(u32) &&
0044         attr->value_size != sizeof(u64))
0045         return -EINVAL;
0046 
0047     return array_map_alloc_check(attr);
0048 }
0049 
0050 static void *reuseport_array_lookup_elem(struct bpf_map *map, void *key)
0051 {
0052     struct reuseport_array *array = reuseport_array(map);
0053     u32 index = *(u32 *)key;
0054 
0055     if (unlikely(index >= array->map.max_entries))
0056         return NULL;
0057 
0058     return rcu_dereference(array->ptrs[index]);
0059 }
0060 
/* Called from syscall only */
static int reuseport_array_delete_elem(struct bpf_map *map, void *key)
{
	struct reuseport_array *array = reuseport_array(map);
	u32 index = *(u32 *)key;
	struct sock *sk;
	int err;

	if (index >= map->max_entries)
		return -E2BIG;

	/* Lockless peek: skip taking reuseport_lock when the slot is
	 * already empty.
	 */
	if (!rcu_access_pointer(array->ptrs[index]))
		return -ENOENT;

	spin_lock_bh(&reuseport_lock);

	/* Re-read under reuseport_lock; the lockless check above may
	 * have raced with a concurrent detach.
	 */
	sk = rcu_dereference_protected(array->ptrs[index],
				       lockdep_is_held(&reuseport_lock));
	if (sk) {
		/* Clear sk_user_data and the array slot together under
		 * sk_callback_lock so bpf_sk_reuseport_detach() cannot
		 * observe a half-updated state.
		 */
		write_lock_bh(&sk->sk_callback_lock);
		WRITE_ONCE(sk->sk_user_data, NULL);
		RCU_INIT_POINTER(array->ptrs[index], NULL);
		write_unlock_bh(&sk->sk_callback_lock);
		err = 0;
	} else {
		err = -ENOENT;
	}

	spin_unlock_bh(&reuseport_lock);

	return err;
}
0093 
/* Release the map: detach every remaining socket from the array, then
 * free the backing memory. Runs when the map's last reference is gone,
 * so no map_*_elem() op can still be in flight.
 */
static void reuseport_array_free(struct bpf_map *map)
{
	struct reuseport_array *array = reuseport_array(map);
	struct sock *sk;
	u32 i;

	/*
	 * ops->map_*_elem() will not be able to access this
	 * array now. Hence, this function only races with
	 * bpf_sk_reuseport_detach() which was triggered by
	 * close() or disconnect().
	 *
	 * This function and bpf_sk_reuseport_detach() are
	 * both removing sk from "array".  Who removes it
	 * first does not matter.
	 *
	 * The only concern here is bpf_sk_reuseport_detach()
	 * may access "array" which is being freed here.
	 * bpf_sk_reuseport_detach() access this "array"
	 * through sk->sk_user_data _and_ with sk->sk_callback_lock
	 * held which is enough because this "array" is not freed
	 * until all sk->sk_user_data has stopped referencing this "array".
	 *
	 * Hence, due to the above, taking "reuseport_lock" is not
	 * needed here.
	 */

	/*
	 * Since reuseport_lock is not taken, sk is accessed under
	 * rcu_read_lock()
	 */
	rcu_read_lock();
	for (i = 0; i < map->max_entries; i++) {
		sk = rcu_dereference(array->ptrs[i]);
		if (sk) {
			write_lock_bh(&sk->sk_callback_lock);
			/*
			 * No need for WRITE_ONCE(). At this point,
			 * no one is reading it without taking the
			 * sk->sk_callback_lock.
			 */
			sk->sk_user_data = NULL;
			write_unlock_bh(&sk->sk_callback_lock);
			RCU_INIT_POINTER(array->ptrs[i], NULL);
		}
	}
	rcu_read_unlock();

	/*
	 * Once reaching here, all sk->sk_user_data is not
	 * referencing this "array". "array" can be freed now.
	 */
	bpf_map_area_free(array);
}
0148 
0149 static struct bpf_map *reuseport_array_alloc(union bpf_attr *attr)
0150 {
0151     int numa_node = bpf_map_attr_numa_node(attr);
0152     struct reuseport_array *array;
0153 
0154     if (!bpf_capable())
0155         return ERR_PTR(-EPERM);
0156 
0157     /* allocate all map elements and zero-initialize them */
0158     array = bpf_map_area_alloc(struct_size(array, ptrs, attr->max_entries), numa_node);
0159     if (!array)
0160         return ERR_PTR(-ENOMEM);
0161 
0162     /* copy mandatory map attributes */
0163     bpf_map_init_from_attr(&array->map, attr);
0164 
0165     return &array->map;
0166 }
0167 
0168 int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map, void *key,
0169                        void *value)
0170 {
0171     struct sock *sk;
0172     int err;
0173 
0174     if (map->value_size != sizeof(u64))
0175         return -ENOSPC;
0176 
0177     rcu_read_lock();
0178     sk = reuseport_array_lookup_elem(map, key);
0179     if (sk) {
0180         *(u64 *)value = __sock_gen_cookie(sk);
0181         err = 0;
0182     } else {
0183         err = -ENOENT;
0184     }
0185     rcu_read_unlock();
0186 
0187     return err;
0188 }
0189 
0190 static int
0191 reuseport_array_update_check(const struct reuseport_array *array,
0192                  const struct sock *nsk,
0193                  const struct sock *osk,
0194                  const struct sock_reuseport *nsk_reuse,
0195                  u32 map_flags)
0196 {
0197     if (osk && map_flags == BPF_NOEXIST)
0198         return -EEXIST;
0199 
0200     if (!osk && map_flags == BPF_EXIST)
0201         return -ENOENT;
0202 
0203     if (nsk->sk_protocol != IPPROTO_UDP && nsk->sk_protocol != IPPROTO_TCP)
0204         return -ENOTSUPP;
0205 
0206     if (nsk->sk_family != AF_INET && nsk->sk_family != AF_INET6)
0207         return -ENOTSUPP;
0208 
0209     if (nsk->sk_type != SOCK_STREAM && nsk->sk_type != SOCK_DGRAM)
0210         return -ENOTSUPP;
0211 
0212     /*
0213      * sk must be hashed (i.e. listening in the TCP case or binded
0214      * in the UDP case) and
0215      * it must also be a SO_REUSEPORT sk (i.e. reuse cannot be NULL).
0216      *
0217      * Also, sk will be used in bpf helper that is protected by
0218      * rcu_read_lock().
0219      */
0220     if (!sock_flag(nsk, SOCK_RCU_FREE) || !sk_hashed(nsk) || !nsk_reuse)
0221         return -EINVAL;
0222 
0223     /* READ_ONCE because the sk->sk_callback_lock may not be held here */
0224     if (READ_ONCE(nsk->sk_user_data))
0225         return -EBUSY;
0226 
0227     return 0;
0228 }
0229 
/*
 * Called from syscall only.
 * The "nsk" in the fd refcnt.
 * The "osk" and "reuse" are protected by reuseport_lock.
 */
int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key,
				       void *value, u64 map_flags)
{
	struct reuseport_array *array = reuseport_array(map);
	struct sock *free_osk = NULL, *osk, *nsk;
	struct sock_reuseport *reuse;
	u32 index = *(u32 *)key;
	uintptr_t sk_user_data;
	struct socket *socket;
	int err, fd;

	if (map_flags > BPF_EXIST)
		return -EINVAL;

	if (index >= map->max_entries)
		return -E2BIG;

	/* The value is an fd: either a raw int (u32 maps) or a u64 that
	 * must still fit in a non-negative s32.
	 */
	if (map->value_size == sizeof(u64)) {
		u64 fd64 = *(u64 *)value;

		if (fd64 > S32_MAX)
			return -EINVAL;
		fd = fd64;
	} else {
		fd = *(int *)value;
	}

	/* Takes a reference on the socket's file; dropped via fput(). */
	socket = sockfd_lookup(fd, &err);
	if (!socket)
		return err;

	nsk = socket->sk;
	if (!nsk) {
		err = -EINVAL;
		goto put_file;
	}

	/* Quick checks before taking reuseport_lock */
	err = reuseport_array_update_check(array, nsk,
					   rcu_access_pointer(array->ptrs[index]),
					   rcu_access_pointer(nsk->sk_reuseport_cb),
					   map_flags);
	if (err)
		goto put_file;

	spin_lock_bh(&reuseport_lock);
	/*
	 * Some of the checks only need reuseport_lock
	 * but it is done under sk_callback_lock also
	 * for simplicity reason.
	 */
	write_lock_bh(&nsk->sk_callback_lock);

	/* Re-check under the locks: the quick check above was lockless
	 * and may have raced with a concurrent update/detach.
	 */
	osk = rcu_dereference_protected(array->ptrs[index],
					lockdep_is_held(&reuseport_lock));
	reuse = rcu_dereference_protected(nsk->sk_reuseport_cb,
					  lockdep_is_held(&reuseport_lock));
	err = reuseport_array_update_check(array, nsk, osk, reuse, map_flags);
	if (err)
		goto put_file_unlock;

	/* Tag sk_user_data so bpf_sk_reuseport_detach() can recognize
	 * it as a back-pointer into array->ptrs[index].
	 */
	sk_user_data = (uintptr_t)&array->ptrs[index] | SK_USER_DATA_NOCOPY |
		SK_USER_DATA_BPF;
	WRITE_ONCE(nsk->sk_user_data, (void *)sk_user_data);
	rcu_assign_pointer(array->ptrs[index], nsk);
	free_osk = osk;
	err = 0;

put_file_unlock:
	write_unlock_bh(&nsk->sk_callback_lock);

	/* Detach the displaced old socket, if any, under its own
	 * callback lock while reuseport_lock is still held.
	 */
	if (free_osk) {
		write_lock_bh(&free_osk->sk_callback_lock);
		WRITE_ONCE(free_osk->sk_user_data, NULL);
		write_unlock_bh(&free_osk->sk_callback_lock);
	}

	spin_unlock_bh(&reuseport_lock);
put_file:
	fput(socket->file);
	return err;
}
0317 
0318 /* Called from syscall */
0319 static int reuseport_array_get_next_key(struct bpf_map *map, void *key,
0320                     void *next_key)
0321 {
0322     struct reuseport_array *array = reuseport_array(map);
0323     u32 index = key ? *(u32 *)key : U32_MAX;
0324     u32 *next = (u32 *)next_key;
0325 
0326     if (index >= array->map.max_entries) {
0327         *next = 0;
0328         return 0;
0329     }
0330 
0331     if (index == array->map.max_entries - 1)
0332         return -ENOENT;
0333 
0334     *next = index + 1;
0335     return 0;
0336 }
0337 
/* BTF type id for struct reuseport_array, exposed via map_btf_id below. */
BTF_ID_LIST_SINGLE(reuseport_array_map_btf_ids, struct, reuseport_array)
/* Operations table wiring this map type into the generic bpf_map core.
 * Note: updates/fd-lookups go through the bpf_fd_* entry points above
 * (syscall path), not through this table.
 */
const struct bpf_map_ops reuseport_array_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc_check = reuseport_array_alloc_check,
	.map_alloc = reuseport_array_alloc,
	.map_free = reuseport_array_free,
	.map_lookup_elem = reuseport_array_lookup_elem,
	.map_get_next_key = reuseport_array_get_next_key,
	.map_delete_elem = reuseport_array_delete_elem,
	.map_btf_id = &reuseport_array_map_btf_ids[0],
};