/* SPDX-License-Identifier: GPL-2.0 */
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#ifndef _LINUX_SKMSG_H
#define _LINUX_SKMSG_H

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/scatterlist.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS			(MAX_MSG_FRAGS + 1)

enum __sk_action {
	__SK_DROP = 0,
	__SK_PASS,
	__SK_REDIRECT,
	__SK_NONE,
};

struct sk_msg_sg {
	u32				start;
	u32				curr;
	u32				end;
	u32				size;
	u32				copybreak;
	DECLARE_BITMAP(copy, MAX_MSG_FRAGS + 2);
	/* data[] holds two elements beyond MAX_MSG_FRAGS: the ring
	 * iterates over NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1) index
	 * values so that sk_msg_iter_dist() can tell a full ring from
	 * an empty one, and one extra slot leaves room for a
	 * terminating entry; the BUILD_BUG_ON() in sk_msg_init() pins
	 * this size.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};

/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;
	void				*data_end;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;
	struct sk_buff			*skb;
	struct sock			*sk_redir;
	struct sock			*sk;
	struct list_head		list;
};

struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*stream_parser;
	struct bpf_prog			*stream_verdict;
	struct bpf_prog			*skb_verdict;
};

enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
};

struct sk_psock_link {
	struct list_head		list;
	struct bpf_map			*map;
	void				*link_raw;
};

struct sk_psock_work_state {
	struct sk_buff			*skb;
	u32				len;
	u32				off;
};

struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;
	struct sk_msg			*cork;
	struct sk_psock_progs		progs;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	struct strparser		strp;
#endif
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	spinlock_t			ingress_lock;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
	refcount_t			refcnt;
	void (*saved_unhash)(struct sock *sk);
	void (*saved_destroy)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	void (*saved_data_ready)(struct sock *sk);
	int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
				     bool restore);
	struct proto			*sk_proto;
	struct mutex			work_mutex;
	struct sk_psock_work_state	work_state;
	struct work_struct		work;
	struct rcu_work			rwork;
};

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags);
bool sk_msg_is_readable(struct sock *sk);

static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}

static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
{
	if (psock->apply_bytes) {
		if (psock->apply_bytes < bytes)
			psock->apply_bytes = 0;
		else
			psock->apply_bytes -= bytes;
	}
}

static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
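
/* Example: how sk_msg_iter_dist() handles ring wrap-around. A minimal
 * illustrative sketch, not part of the upstream header; the helper name
 * and the index values are hypothetical.
 */
static inline void sk_msg_iter_dist_example(void)
{
	/* No wrap: start = 2, end = 5 -> 3 elements in use. */
	WARN_ON(sk_msg_iter_dist(2, 5) != 3);
	/* Wrap: start = NR_MSG_FRAG_IDS - 1, end = 2 -> the end index
	 * already passed 0 again, so 3 elements are in use.
	 */
	WARN_ON(sk_msg_iter_dist(NR_MSG_FRAG_IDS - 1, 2) != 3);
}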

#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)

static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}

static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length = size;
	dst->sg.size += size;
	src->sg.size -= size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}

static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
{
	memcpy(dst, src, sizeof(*src));
	sk_msg_init(src);
}

static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}

static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}

static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{
	return &msg->sg.data[which];
}
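
/* Example: walking the used elements of the ring with the iterator
 * macros. A minimal sketch, not part of the upstream API; the helper
 * name is hypothetical.
 */
static inline u32 sk_msg_example_bytes_queued(struct sk_msg *msg)
{
	u32 i = msg->sg.start;
	u32 bytes = 0;

	/* msg->sg.end is one past the last used element, so an empty
	 * ring (start == end) never enters the loop.
	 */
	while (i != msg->sg.end) {
		bytes += sk_msg_elem(msg, i)->length;
		sk_msg_iter_var_next(i);
	}
	/* For a fully accounted msg this should equal msg->sg.size. */
	return bytes;
}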

static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
{
	return msg->sg.data[which];
}

static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}

static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
{
	return msg->flags & BPF_F_INGRESS;
}

static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
	struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);

	if (test_bit(msg->sg.start, msg->sg.copy)) {
		msg->data = NULL;
		msg->data_end = NULL;
	} else {
		msg->data = sg_virt(sge);
		msg->data_end = msg->data + sge->length;
	}
}

static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	get_page(page);
	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	__set_bit(msg->sg.end, msg->sg.copy);
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}
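
/* Example: appending a page to a msg. A minimal sketch under assumed
 * context (the caller holds a reference on "page" and handles -ENOSPC);
 * the function name is hypothetical.
 */
static inline int sk_msg_example_append(struct sk_msg *msg, struct page *page,
					u32 len, u32 offset)
{
	if (sk_msg_full(msg))
		return -ENOSPC;
	/* sk_msg_page_add() takes its own page reference via get_page()
	 * and advances msg->sg.end past the new element.
	 */
	sk_msg_page_add(msg, page, len, offset);
	/* Recompute msg->data/msg->data_end for BPF msg programs. */
	sk_msg_compute_data_pointers(msg);
	return 0;
}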

static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
	do {
		if (copy_state)
			__set_bit(i, msg->sg.copy);
		else
			__clear_bit(i, msg->sg.copy);
		sk_msg_iter_var_next(i);
		if (i == msg->sg.end)
			break;
	} while (1);
}

static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, true);
}

static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, false);
}

static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return __rcu_dereference_sk_user_data_with_flags(sk,
							 SK_USER_DATA_PSOCK);
}

static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}

static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
{
	sk_drops_add(sk, skb);
	kfree_skb(skb);
}

static inline void sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	spin_lock_bh(&psock->ingress_lock);
	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
		list_add_tail(&msg->list, &psock->ingress_msg);
	} else {
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
	spin_unlock_bh(&psock->ingress_lock);
}

static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	if (msg)
		list_del(&msg->list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}
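
/* Example: draining all queued ingress msgs. A minimal sketch, not
 * upstream code; the helper name and the drop-everything policy are
 * assumed for illustration.
 */
static inline void sk_psock_example_drain(struct sk_psock *psock)
{
	struct sk_msg *msg;

	/* sk_psock_dequeue_msg() takes ingress_lock internally, so the
	 * loop itself needs no extra locking.
	 */
	while ((msg = sk_psock_dequeue_msg(psock))) {
		sk_msg_free(psock->sk, msg);	/* uncharge and drop pages */
		kfree(msg);			/* free the container */
	}
}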

static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}

static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
					       struct sk_msg *msg)
{
	struct sk_msg *ret;

	spin_lock_bh(&psock->ingress_lock);
	if (list_is_last(&msg->list, &psock->ingress_msg))
		ret = NULL;
	else
		ret = list_next_entry(msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return ret;
}

static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
{
	return psock ? list_empty(&psock->ingress_msg) : true;
}

static inline void kfree_sk_msg(struct sk_msg *msg)
{
	if (msg->skb)
		consume_skb(msg->skb);
	kfree(msg);
}

static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk_error_report(sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node);
void sk_psock_stop(struct sk_psock *psock, bool wait);

#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
#else
static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	return -EOPNOTSUPP;
}

static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
}

static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
}
#endif

void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);

static inline struct sk_psock_link *sk_psock_init_link(void)
{
	return kzalloc(sizeof(struct sk_psock_link),
		       GFP_ATOMIC | __GFP_NOWARN);
}

static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);

static inline void sk_psock_cork_free(struct sk_psock *psock)
{
	if (psock->cork) {
		sk_msg_free(psock->sk, psock->cork);
		kfree(psock->cork);
		psock->cork = NULL;
	}
}

static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, psock, true);
}

static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}

void sk_psock_drop(struct sock *sk, struct sk_psock *psock);

static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}
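
/* Example: the get/use/put pattern around a psock. A minimal sketch, not
 * upstream code; the helper name and the elided "use" step are
 * hypothetical.
 */
static inline bool sk_psock_example_use(struct sock *sk)
{
	struct sk_psock *psock;

	/* sk_psock_get() returns NULL if no psock is attached or its
	 * refcount already hit zero (teardown in flight).
	 */
	psock = sk_psock_get(sk);
	if (!psock)
		return false;

	/* ... operate on psock here; the held reference pins it ... */

	sk_psock_put(sk, psock);	/* last put calls sk_psock_drop() */
	return true;
}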

static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
	if (psock->saved_data_ready)
		psock->saved_data_ready(sk);
	else
		sk->sk_data_ready(sk);
}

static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}

static inline int psock_replace_prog(struct bpf_prog **pprog,
				     struct bpf_prog *prog,
				     struct bpf_prog *old)
{
	if (cmpxchg(pprog, old, prog) != old)
		return -ENOENT;

	if (old)
		bpf_prog_put(old);

	return 0;
}
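
/* Example: atomically replacing a verdict program. A minimal sketch, not
 * upstream code; "progs", "old" and "new" are assumed to be supplied by
 * a caller that holds references on both programs.
 */
static inline int psock_example_replace(struct sk_psock_progs *progs,
					struct bpf_prog *old,
					struct bpf_prog *new)
{
	/* Fails with -ENOENT unless the slot still holds "old", so two
	 * concurrent updaters cannot silently clobber each other; on
	 * success the slot's reference on "old" is released.
	 */
	return psock_replace_prog(&progs->stream_verdict, new, old);
}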

static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->stream_parser, NULL);
	psock_set_prog(&progs->stream_verdict, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}

int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);

static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
{
	if (!psock)
		return false;
	return !!psock->saved_data_ready;
}

static inline bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

#if IS_ENABLED(CONFIG_NET_SOCK_MSG)

#define BPF_F_STRPARSER	(1UL << 1)

/* We only have two bits so far. */
#define BPF_F_PTR_MASK	~(BPF_F_INGRESS | BPF_F_STRPARSER)

static inline bool skb_bpf_strparser(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_STRPARSER;
}

static inline void skb_bpf_set_strparser(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_STRPARSER;
}

static inline bool skb_bpf_ingress(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_INGRESS;
}

static inline void skb_bpf_set_ingress(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_INGRESS;
}

static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
				     bool ingress)
{
	skb->_sk_redir = (unsigned long)sk_redir;
	if (ingress)
		skb->_sk_redir |= BPF_F_INGRESS;
}

static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
}

static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
{
	skb->_sk_redir = 0;
}
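
/* Example: the pointer-tagging round trip. A minimal sketch, not upstream
 * code; it relies on struct sock being aligned well beyond 4 bytes, which
 * leaves the low two bits of the pointer free for BPF_F_INGRESS and
 * BPF_F_STRPARSER.
 */
static inline void skb_bpf_example_roundtrip(struct sk_buff *skb,
					     struct sock *sk)
{
	skb_bpf_set_redir(skb, sk, true);	/* pointer | BPF_F_INGRESS */
	WARN_ON(skb_bpf_redirect_fetch(skb) != sk);	/* mask strips flags */
	WARN_ON(!skb_bpf_ingress(skb));
	skb_bpf_redirect_clear(skb);
}
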
#endif /* CONFIG_NET_SOCK_MSG */
#endif /* _LINUX_SKMSG_H */