// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

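/* Queue up to @apply_bytes of @msg on the ingress queue of @sk's psock.
 * Transferred scatterlist entries are charged to @sk and the receiver is
 * woken via sk_psock_data_ready(). Returns 0 on success, or -ENOMEM if
 * the temporary sk_msg cannot be allocated or no memory can be scheduled
 * before anything is copied.
 */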
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

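/* Transmit up to @apply_bytes of @msg on @sk; the caller holds the sock
 * lock. Pages are pushed with do_tcp_sendpages(), or with
 * kernel_sendpage_locked() and MSG_SENDPAGE_NOPOLICY when a TLS TX
 * context is active so the ULP's sendpage path is used without re-running
 * BPF policy. Partially sent entries are retried from the updated offset;
 * fully sent entries are released.
 */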
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

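/* Send @bytes of @msg to @sk as the result of a BPF redirect verdict.
 * Ingress-directed data is queued on the target psock via
 * bpf_tcp_ingress(); egress data is pushed out through the target socket
 * with the sock lock taken here.
 */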
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
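/* Wait up to @timeo for data to appear on either the psock ingress queue
 * or the socket receive queue. Returns nonzero once data is available or
 * the receive side has been shut down, and zero on timeout.
 */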
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

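/* recvmsg() used when a stream/skb verdict program is attached: data is
 * consumed from the psock ingress queue, falling back to tcp_recvmsg()
 * only when no psock is attached to the socket.
 */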
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int flags,
				  int *addr_len)
{
	struct sk_psock *psock;
	int copied;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

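/* recvmsg() for the TX-only configuration: prefer the psock ingress
 * queue, but defer to tcp_recvmsg() when only the regular receive queue
 * holds data or when waiting produced no psock data.
 */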
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

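/* Apply the msg verdict for @msg (running the verdict program if none is
 * pending) and act on the result: pass the data out on this socket,
 * redirect it to another socket, or drop it. Also handles cork buffering
 * and apply_bytes accounting; *@copied is reduced when queued data is
 * freed or dropped.
 */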
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	u32 eval = __SK_NONE;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track the delta in msg size so it can be added back to
		 * the copied count returned to the user on SK_DROP. This
		 * keeps the return code from going positive when the
		 * verdict program trims data and then drops the message.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, msg->sg.size);
		release_sock(sk);

		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, msg->sg.size);
			goto more_data;
		}
	}
	return ret;
}

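/* sendmsg() used when a msg parser program is attached: user data is
 * copied into an sk_msg, optionally corked until cork_bytes have
 * accumulated, and then run through tcp_bpf_send_verdict().
 */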
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal sendpage flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

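/* sendpage() counterpart of tcp_bpf_sendmsg(): add the page to a new or
 * corked sk_msg and run it through the verdict path.
 */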
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch the case where the sk_msg ring is already full. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

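/* tcp_bpf_prots[] below is indexed by address family and by which
 * callbacks must be overridden: the base close/destroy/recvmsg hooks
 * only, TX (sendmsg/sendpage), RX (parser recvmsg), or both.
 */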
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

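/* Build the per-config proto templates from @base, overriding only the
 * callbacks that sockmap needs to interpose on.
 */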
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].destroy = sock_map_destroy;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;

	prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

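/* tcpv6_prot is not directly accessible here (IPv6 can be a module), so
 * the IPv6 templates are built lazily from the proto ops of the first
 * TCPv6 socket seen, and rebuilt if those ops ever change.
 */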
static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

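/* Switch @sk between its saved proto ops and the tcp_bpf variants. With
 * @restore set, the original callbacks are put back (via the ULP layer
 * when one is active); otherwise the proto matching the attached program
 * types is installed.
 */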
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to make sure it stops using the
			 * sock_map unhash routine because the associated
			 * psock is being removed. So use the saved unhash
			 * handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		}
		return 0;
	}

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child socket was cloned from a parent that had tcp_bpf protocol
 * callbacks installed, the child does not inherit the parent's psock, so
 * it must not keep the tcp_bpf callbacks either. Reset the child to the
 * proto it was created with.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif