Back to home page

OSCL-LXR

 
 

    


0001 /* SPDX-License-Identifier: GPL-2.0 */
0002 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
0003 
0004 #ifndef _LINUX_SKMSG_H
0005 #define _LINUX_SKMSG_H
0006 
0007 #include <linux/bpf.h>
0008 #include <linux/filter.h>
0009 #include <linux/scatterlist.h>
0010 #include <linux/skbuff.h>
0011 
0012 #include <net/sock.h>
0013 #include <net/tcp.h>
0014 #include <net/strparser.h>
0015 
0016 #define MAX_MSG_FRAGS           MAX_SKB_FRAGS
0017 #define NR_MSG_FRAG_IDS         (MAX_MSG_FRAGS + 1)
0018 
/* Verdict a BPF sk_msg/skb program yields for a message or skb. */
enum __sk_action {
	__SK_DROP = 0,		/* discard the message/skb */
	__SK_PASS,		/* deliver to the owning socket */
	__SK_REDIRECT,		/* send to another socket (see sk_redir fields) */
	__SK_NONE,		/* no verdict recorded */
};
0025 
/* Scatterlist view of a socket message.
 *
 * data[] is used as a ring indexed by [start, end) (see sk_msg_iter_dist());
 * @size is the total byte count across live elements, and the @copy bitmap
 * marks elements whose payload must be copied rather than mapped directly
 * (see sk_msg_compute_data_pointers()).
 */
struct sk_msg_sg {
	u32				start;		/* first live ring element */
	u32				curr;		/* iteration cursor — presumably copy progress; confirm in skmsg.c */
	u32				end;		/* one past the last live element */
	u32				size;		/* total bytes in live elements */
	u32				copybreak;	/* NOTE(review): looks like a partial-copy offset — confirm */
	DECLARE_BITMAP(copy, MAX_MSG_FRAGS + 2);
	/* The extra two elements:
	 * 1) used for chaining the front and sections when the list becomes
	 *    partitioned (e.g. end < start). The crypto APIs require the
	 *    chaining;
	 * 2) to chain tailer SG entries after the message.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};
0041 
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;		/* first element payload, or NULL; set by sk_msg_compute_data_pointers() */
	void				*data_end;	/* one past the accessible payload, or NULL */
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;		/* BPF flags, e.g. BPF_F_INGRESS (see sk_msg_to_ingress()) */
	struct sk_buff			*skb;		/* originating skb, consumed in kfree_sk_msg() */
	struct sock			*sk_redir;	/* redirect target for __SK_REDIRECT verdicts */
	struct sock			*sk;
	struct list_head		list;		/* linkage on psock->ingress_msg */
};
0055 
/* BPF programs attached to a psock; any slot may be NULL when unused.
 * Managed via psock_set_prog()/psock_replace_prog()/psock_progs_drop().
 */
struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*stream_parser;
	struct bpf_prog			*stream_verdict;
	struct bpf_prog			*skb_verdict;
};

/* Bit numbers for the sk_psock::state bitmask. */
enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,	/* when clear, sk_psock_queue_msg() frees instead of queueing */
};

/* Links a psock to a sockmap/sockhash entry — presumably; confirm in bpf/sockmap. */
struct sk_psock_link {
	struct list_head		list;		/* linkage on psock->link */
	struct bpf_map			*map;
	void				*link_raw;	/* map-internal element handle — TODO confirm */
};

/* Progress snapshot for the psock background work item. */
struct sk_psock_work_state {
	struct sk_buff			*skb;
	u32				len;
	u32				off;
};
0078 
/* Per-socket BPF sockmap state, reachable from the socket via sk_psock(). */
struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;	/* redirect target from the last verdict */
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;		/* presumably cached __sk_action verdict — confirm */
	struct sk_msg			*cork;		/* pending corked msg; freed by sk_psock_cork_free() */
	struct sk_psock_progs		progs;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	struct strparser		strp;
#endif
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;	/* protected by ingress_lock */
	spinlock_t			ingress_lock;
	unsigned long			state;		/* bitmask of sk_psock_state_bits */
	struct list_head		link;		/* sk_psock_link list; protected by link_lock */
	spinlock_t			link_lock;
	refcount_t			refcnt;		/* see sk_psock_get()/sk_psock_put() */
	/* Original socket callbacks, saved so they can be restored/chained. */
	void (*saved_unhash)(struct sock *sk);
	void (*saved_destroy)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	void (*saved_data_ready)(struct sock *sk);
	int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
				     bool restore);
	struct proto			*sk_proto;	/* original proto ops — presumably; restored via psock_update_sk_prot */
	struct mutex			work_mutex;
	struct sk_psock_work_state	work_state;
	struct work_struct		work;
	struct rcu_work			rwork;
};
0110 
/* sk_msg memory management — implemented elsewhere (presumably
 * net/core/skmsg.c; confirm). The _nocharge variants skip socket memory
 * accounting, judging by their names — verify against the definitions.
 */
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);

/* Data movement between iovecs/sockets and sk_msg scatterlists. */
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags);
bool sk_msg_is_readable(struct sock *sk);
0132 
/* Sanity check for partial-free paths: once the iterator has reached
 * sg.end there must be no bytes left to free.
 */
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}
0137 
0138 static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
0139 {
0140     if (psock->apply_bytes) {
0141         if (psock->apply_bytes < bytes)
0142             psock->apply_bytes = 0;
0143         else
0144             psock->apply_bytes -= bytes;
0145     }
0146 }
0147 
0148 static inline u32 sk_msg_iter_dist(u32 start, u32 end)
0149 {
0150     return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
0151 }
0152 
/* Step a ring index back by one, wrapping 0 -> NR_MSG_FRAG_IDS - 1. */
#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)

/* Step a ring index forward by one, wrapping NR_MSG_FRAG_IDS -> 0. */
#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

/* Step msg->sg.<which> (start/curr/end) backward with wrap-around. */
#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

/* Step msg->sg.<which> (start/curr/end) forward with wrap-around. */
#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)
0173 
/* Zero @msg and prepare its scatterlist ring for use.
 *
 * The BUILD_BUG_ON pins the layout invariant: data[] must have exactly
 * one slot more than NR_MSG_FRAG_IDS (the extra tail-chaining entry
 * described in struct sk_msg_sg).
 */
static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
0180 
/* Move @size bytes of scatterlist element @which from @src to @dst.
 *
 * dst gets a copy of the src entry truncated to @size; src's entry is
 * advanced past the transferred bytes and both sg.size totals are
 * adjusted. The dst copy must happen before src's length/offset are
 * modified — do not reorder these statements.
 */
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length  = size;
	dst->sg.size	       += size;
	src->sg.size	       -= size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}
0191 
0192 static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
0193 {
0194     memcpy(dst, src, sizeof(*src));
0195     sk_msg_init(src);
0196 }
0197 
0198 static inline bool sk_msg_full(const struct sk_msg *msg)
0199 {
0200     return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
0201 }
0202 
0203 static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
0204 {
0205     return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
0206 }
0207 
0208 static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
0209 {
0210     return &msg->sg.data[which];
0211 }
0212 
0213 static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
0214 {
0215     return msg->sg.data[which];
0216 }
0217 
0218 static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
0219 {
0220     return sg_page(sk_msg_elem(msg, which));
0221 }
0222 
0223 static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
0224 {
0225     return msg->flags & BPF_F_INGRESS;
0226 }
0227 
0228 static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
0229 {
0230     struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
0231 
0232     if (test_bit(msg->sg.start, msg->sg.copy)) {
0233         msg->data = NULL;
0234         msg->data_end = NULL;
0235     } else {
0236         msg->data = sg_virt(sge);
0237         msg->data_end = msg->data + sge->length;
0238     }
0239 }
0240 
/* Append @len bytes of @page starting at @offset as a new tail element.
 *
 * Takes a page reference, marks the new element's copy bit, grows
 * sg.size and advances sg.end. Caller must ensure the ring is not full
 * (see sk_msg_full()) — the slot at sg.end is overwritten unchecked.
 */
static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	get_page(page);
	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	__set_bit(msg->sg.end, msg->sg.copy);
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}
0255 
0256 static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
0257 {
0258     do {
0259         if (copy_state)
0260             __set_bit(i, msg->sg.copy);
0261         else
0262             __clear_bit(i, msg->sg.copy);
0263         sk_msg_iter_var_next(i);
0264         if (i == msg->sg.end)
0265             break;
0266     } while (1);
0267 }
0268 
0269 static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
0270 {
0271     sk_msg_sg_copy(msg, start, true);
0272 }
0273 
0274 static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
0275 {
0276     sk_msg_sg_copy(msg, start, false);
0277 }
0278 
/* Fetch the psock attached to @sk via sk_user_data, or NULL. Caller must
 * be in an RCU read section or otherwise keep the psock alive (see
 * sk_psock_get() for the refcounted variant).
 */
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return __rcu_dereference_sk_user_data_with_flags(sk,
							 SK_USER_DATA_PSOCK);
}

/* Atomically set a psock state bit (enum sk_psock_state_bits). */
static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

/* Atomically clear a psock state bit. */
static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

/* Test a psock state bit. */
static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}

/* Free @skb while accounting it as a drop on @sk. */
static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
{
	sk_drops_add(sk, skb);
	kfree_skb(skb);
}
0308 
0309 static inline void sk_psock_queue_msg(struct sk_psock *psock,
0310                       struct sk_msg *msg)
0311 {
0312     spin_lock_bh(&psock->ingress_lock);
0313     if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
0314         list_add_tail(&msg->list, &psock->ingress_msg);
0315     else {
0316         sk_msg_free(psock->sk, msg);
0317         kfree(msg);
0318     }
0319     spin_unlock_bh(&psock->ingress_lock);
0320 }
0321 
0322 static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
0323 {
0324     struct sk_msg *msg;
0325 
0326     spin_lock_bh(&psock->ingress_lock);
0327     msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
0328     if (msg)
0329         list_del(&msg->list);
0330     spin_unlock_bh(&psock->ingress_lock);
0331     return msg;
0332 }
0333 
0334 static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
0335 {
0336     struct sk_msg *msg;
0337 
0338     spin_lock_bh(&psock->ingress_lock);
0339     msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
0340     spin_unlock_bh(&psock->ingress_lock);
0341     return msg;
0342 }
0343 
0344 static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
0345                            struct sk_msg *msg)
0346 {
0347     struct sk_msg *ret;
0348 
0349     spin_lock_bh(&psock->ingress_lock);
0350     if (list_is_last(&msg->list, &psock->ingress_msg))
0351         ret = NULL;
0352     else
0353         ret = list_next_entry(msg, list);
0354     spin_unlock_bh(&psock->ingress_lock);
0355     return ret;
0356 }
0357 
0358 static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
0359 {
0360     return psock ? list_empty(&psock->ingress_msg) : true;
0361 }
0362 
0363 static inline void kfree_sk_msg(struct sk_msg *msg)
0364 {
0365     if (msg->skb)
0366         consume_skb(msg->skb);
0367     kfree(msg);
0368 }
0369 
/* Raise error @err on the psock's socket and notify waiters. */
static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk_error_report(sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node);
void sk_psock_stop(struct sk_psock *psock, bool wait);
0380 
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
#else
/* Stream parser compiled out: init reports no support and the
 * start/stop hooks collapse to no-ops.
 */
static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	return -EOPNOTSUPP;
}

static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
}

static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
}
#endif

void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);
0405 
0406 static inline struct sk_psock_link *sk_psock_init_link(void)
0407 {
0408     return kzalloc(sizeof(struct sk_psock_link),
0409                GFP_ATOMIC | __GFP_NOWARN);
0410 }
0411 
0412 static inline void sk_psock_free_link(struct sk_psock_link *link)
0413 {
0414     kfree(link);
0415 }
0416 
0417 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
0418 
0419 static inline void sk_psock_cork_free(struct sk_psock *psock)
0420 {
0421     if (psock->cork) {
0422         sk_msg_free(psock->sk, psock->cork);
0423         kfree(psock->cork);
0424         psock->cork = NULL;
0425     }
0426 }
0427 
/* Ask the psock to restore the socket's original protocol callbacks
 * (the restore == true path of psock_update_sk_prot), if the hook is set.
 */
static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, psock, true);
}
0434 
/* Take a reference on @sk's psock. Returns NULL when no psock is
 * attached or its refcount already hit zero (teardown in progress) —
 * refcount_inc_not_zero() under RCU guards against reviving a dying
 * psock.
 */
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}

void sk_psock_drop(struct sock *sk, struct sk_psock *psock);

/* Drop a reference from sk_psock_get(); the final put tears the psock
 * down via sk_psock_drop().
 */
static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}
0454 
/* Invoke the socket's saved ->sk_data_ready callback if the psock
 * captured one, otherwise fall back to the socket's current callback.
 */
static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
	if (psock->saved_data_ready)
		psock->saved_data_ready(sk);
	else
		sk->sk_data_ready(sk);
}
0462 
/* Install @prog (may be NULL) into *pprog, atomically swapping out and
 * releasing whatever program was there before.
 */
static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}

/* Replace *pprog with @prog only if it still holds @old; returns -ENOENT
 * when a concurrent update changed it first. Releases @old on success.
 */
static inline int psock_replace_prog(struct bpf_prog **pprog,
				     struct bpf_prog *prog,
				     struct bpf_prog *old)
{
	if (cmpxchg(pprog, old, prog) != old)
		return -ENOENT;

	if (old)
		bpf_prog_put(old);

	return 0;
}

/* Detach and release every program attached to @progs. */
static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->stream_parser, NULL);
	psock_set_prog(&progs->stream_verdict, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}
0491 
0492 int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
0493 
0494 static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
0495 {
0496     if (!psock)
0497         return false;
0498     return !!psock->saved_data_ready;
0499 }
0500 
0501 static inline bool sk_is_udp(const struct sock *sk)
0502 {
0503     return sk->sk_type == SOCK_DGRAM &&
0504            sk->sk_protocol == IPPROTO_UDP;
0505 }
0506 
0507 #if IS_ENABLED(CONFIG_NET_SOCK_MSG)
0508 
/* Flag stored alongside BPF_F_INGRESS in the low bits of skb->_sk_redir;
 * records that the verdict came through the strparser path.
 */
#define BPF_F_STRPARSER (1UL << 1)

/* We only have two bits so far. */
/* Mask that strips the flag bits to recover the sock pointer stashed in
 * skb->_sk_redir (see skb_bpf_redirect_fetch()).
 */
#define BPF_F_PTR_MASK ~(BPF_F_INGRESS | BPF_F_STRPARSER)
0513 
0514 static inline bool skb_bpf_strparser(const struct sk_buff *skb)
0515 {
0516     unsigned long sk_redir = skb->_sk_redir;
0517 
0518     return sk_redir & BPF_F_STRPARSER;
0519 }
0520 
0521 static inline void skb_bpf_set_strparser(struct sk_buff *skb)
0522 {
0523     skb->_sk_redir |= BPF_F_STRPARSER;
0524 }
0525 
0526 static inline bool skb_bpf_ingress(const struct sk_buff *skb)
0527 {
0528     unsigned long sk_redir = skb->_sk_redir;
0529 
0530     return sk_redir & BPF_F_INGRESS;
0531 }
0532 
0533 static inline void skb_bpf_set_ingress(struct sk_buff *skb)
0534 {
0535     skb->_sk_redir |= BPF_F_INGRESS;
0536 }
0537 
0538 static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
0539                      bool ingress)
0540 {
0541     skb->_sk_redir = (unsigned long)sk_redir;
0542     if (ingress)
0543         skb->_sk_redir |= BPF_F_INGRESS;
0544 }
0545 
0546 static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
0547 {
0548     unsigned long sk_redir = skb->_sk_redir;
0549 
0550     return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
0551 }
0552 
0553 static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
0554 {
0555     skb->_sk_redir = 0;
0556 }
0557 #endif /* CONFIG_NET_SOCK_MSG */
0558 #endif /* _LINUX_SKMSG_H */