// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

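/* Transfer up to @apply_bytes worth of scatterlist elements from @msg to a
 * freshly allocated sk_msg and queue it on the ingress list of @psock,
 * charging the transferred memory to @sk. On success the donor msg's start
 * index is advanced past the moved elements and the socket is marked
 * data-ready; on failure the temporary sk_msg is freed again.
 */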
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
               struct sk_msg *msg, u32 apply_bytes, int flags)
{
    bool apply = apply_bytes;
    struct scatterlist *sge;
    u32 size, copied = 0;
    struct sk_msg *tmp;
    int i, ret = 0;

    tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
    if (unlikely(!tmp))
        return -ENOMEM;

    lock_sock(sk);
    tmp->sg.start = msg->sg.start;
    i = msg->sg.start;
    do {
        sge = sk_msg_elem(msg, i);
        size = (apply && apply_bytes < sge->length) ?
            apply_bytes : sge->length;
        if (!sk_wmem_schedule(sk, size)) {
            if (!copied)
                ret = -ENOMEM;
            break;
        }

        sk_mem_charge(sk, size);
        sk_msg_xfer(tmp, msg, i, size);
        copied += size;
        if (sge->length)
            get_page(sk_msg_page(tmp, i));
        sk_msg_iter_var_next(i);
        tmp->sg.end = i;
        if (apply) {
            apply_bytes -= size;
            if (!apply_bytes)
                break;
        }
    } while (i != msg->sg.end);

    if (!ret) {
        msg->sg.start = i;
        sk_psock_queue_msg(psock, tmp);
        sk_psock_data_ready(sk, psock);
    } else {
        sk_msg_free(sk, tmp);
        kfree(tmp);
    }

    release_sock(sk);
    return ret;
}

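/* Push the pages of @msg out on @sk, limited to @apply_bytes when that is
 * non-zero. If the socket has a TLS TX software context the data goes out
 * through kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY set (so the
 * TLS layer does not run the BPF policy a second time); otherwise
 * do_tcp_sendpages() is used. Fully sent elements are unlinked from the
 * scatterlist, and their memory is uncharged when @uncharge is set. The
 * caller must hold the socket lock.
 */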
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
            int flags, bool uncharge)
{
    bool apply = apply_bytes;
    struct scatterlist *sge;
    struct page *page;
    int size, ret = 0;
    u32 off;

    while (1) {
        bool has_tx_ulp;

        sge = sk_msg_elem(msg, msg->sg.start);
        size = (apply && apply_bytes < sge->length) ?
            apply_bytes : sge->length;
        off  = sge->offset;
        page = sg_page(sge);

        tcp_rate_check_app_limited(sk);
retry:
        has_tx_ulp = tls_sw_has_ctx_tx(sk);
        if (has_tx_ulp) {
            flags |= MSG_SENDPAGE_NOPOLICY;
            ret = kernel_sendpage_locked(sk,
                             page, off, size, flags);
        } else {
            ret = do_tcp_sendpages(sk, page, off, size, flags);
        }

        if (ret <= 0)
            return ret;
        if (apply)
            apply_bytes -= ret;
        msg->sg.size -= ret;
        sge->offset += ret;
        sge->length -= ret;
        if (uncharge)
            sk_mem_uncharge(sk, ret);
        if (ret != size) {
            size -= ret;
            off  += ret;
            goto retry;
        }
        if (!sge->length) {
            put_page(page);
            sk_msg_iter_next(msg, start);
            sg_init_table(sge, 1);
            if (msg->sg.start == msg->sg.end)
                break;
        }
        if (apply && !apply_bytes)
            break;
    }

    return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                   u32 apply_bytes, int flags, bool uncharge)
{
    int ret;

    lock_sock(sk);
    ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
    release_sock(sk);
    return ret;
}

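/* Redirect path entry point: send @bytes of @msg on @sk, the socket chosen
 * by the verdict program. Depending on the ingress flag recorded in @msg the
 * data is either queued on the target psock's ingress list or pushed out on
 * the wire.
 */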
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
              u32 bytes, int flags)
{
    bool ingress = sk_msg_to_ingress(msg);
    struct sk_psock *psock = sk_psock_get(sk);
    int ret;

    if (unlikely(!psock))
        return -EPIPE;

    ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
            tcp_bpf_push_locked(sk, msg, bytes, flags, false);
    sk_psock_put(sk, psock);
    return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
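/* Wait up to @timeo for data to show up on either the psock ingress list or
 * the regular receive queue. Returns non-zero if data is available or the
 * receive side has been shut down, zero on timeout.
 */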
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
                 long timeo)
{
    DEFINE_WAIT_FUNC(wait, woken_wake_function);
    int ret = 0;

    if (sk->sk_shutdown & RCV_SHUTDOWN)
        return 1;

    if (!timeo)
        return ret;

    add_wait_queue(sk_sleep(sk), &wait);
    sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
    ret = sk_wait_event(sk, &timeo,
                !list_empty(&psock->ingress_msg) ||
                !skb_queue_empty(&sk->sk_receive_queue), &wait);
    sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
    remove_wait_queue(sk_sleep(sk), &wait);
    return ret;
}

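/* recvmsg handler installed when a stream/skb verdict program is attached.
 * Data is only consumed from the psock ingress list; when the list is empty
 * the usual TCP state and error checks are performed before sleeping in
 * tcp_msg_wait_data() for more parsed data.
 */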
static int tcp_bpf_recvmsg_parser(struct sock *sk,
                  struct msghdr *msg,
                  size_t len,
                  int flags,
                  int *addr_len)
{
    struct sk_psock *psock;
    int copied;

    if (unlikely(flags & MSG_ERRQUEUE))
        return inet_recv_error(sk, msg, len, addr_len);

    psock = sk_psock_get(sk);
    if (unlikely(!psock))
        return tcp_recvmsg(sk, msg, len, flags, addr_len);

    lock_sock(sk);
msg_bytes_ready:
    copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
    if (!copied) {
        long timeo;
        int data;

        if (sock_flag(sk, SOCK_DONE))
            goto out;

        if (sk->sk_err) {
            copied = sock_error(sk);
            goto out;
        }

        if (sk->sk_shutdown & RCV_SHUTDOWN)
            goto out;

        if (sk->sk_state == TCP_CLOSE) {
            copied = -ENOTCONN;
            goto out;
        }

        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
        if (!timeo) {
            copied = -EAGAIN;
            goto out;
        }

        if (signal_pending(current)) {
            copied = sock_intr_errno(timeo);
            goto out;
        }

        data = tcp_msg_wait_data(sk, psock, timeo);
        if (data && !sk_psock_queue_empty(psock))
            goto msg_bytes_ready;
        copied = -EAGAIN;
    }
out:
    release_sock(sk);
    sk_psock_put(sk, psock);
    return copied;
}

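/* recvmsg handler for the base and TX-only configurations. If the psock
 * ingress list is empty but the regular receive queue is not, fall back to
 * tcp_recvmsg(); otherwise copy from the ingress list and wait for more data
 * as needed.
 */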
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
               int flags, int *addr_len)
{
    struct sk_psock *psock;
    int copied, ret;

    if (unlikely(flags & MSG_ERRQUEUE))
        return inet_recv_error(sk, msg, len, addr_len);

    psock = sk_psock_get(sk);
    if (unlikely(!psock))
        return tcp_recvmsg(sk, msg, len, flags, addr_len);
    if (!skb_queue_empty(&sk->sk_receive_queue) &&
        sk_psock_queue_empty(psock)) {
        sk_psock_put(sk, psock);
        return tcp_recvmsg(sk, msg, len, flags, addr_len);
    }
    lock_sock(sk);
msg_bytes_ready:
    copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
    if (!copied) {
        long timeo;
        int data;

        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
        data = tcp_msg_wait_data(sk, psock, timeo);
        if (data) {
            if (!sk_psock_queue_empty(psock))
                goto msg_bytes_ready;
            release_sock(sk);
            sk_psock_put(sk, psock);
            return tcp_recvmsg(sk, msg, len, flags, addr_len);
        }
        copied = -EAGAIN;
    }
    ret = copied;
    release_sock(sk);
    sk_psock_put(sk, psock);
    return ret;
}

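/* Run the msg verdict program, if it has not been evaluated for this msg
 * yet, and act on the result: __SK_PASS pushes the data out on @sk itself,
 * __SK_REDIRECT hands it to psock->sk_redir via tcp_bpf_sendmsg_redir(), and
 * __SK_DROP frees it and returns -EACCES. Cork and apply_bytes state set by
 * the program is honoured, and *copied is adjusted so the byte count
 * returned to user space matches what was actually consumed.
 */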
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                struct sk_msg *msg, int *copied, int flags)
{
    bool cork = false, enospc = sk_msg_full(msg);
    struct sock *sk_redir;
    u32 tosend, delta = 0;
    u32 eval = __SK_NONE;
    int ret;

more_data:
    if (psock->eval == __SK_NONE) {
        /* Track delta in msg size to add/subtract it on SK_DROP from
         * returned to user copied size. This ensures user doesn't
         * get a positive return code with msg_cut_data and SK_DROP
         * verdict.
         */
        delta = msg->sg.size;
        psock->eval = sk_psock_msg_verdict(sk, psock, msg);
        delta -= msg->sg.size;
    }

    if (msg->cork_bytes &&
        msg->cork_bytes > msg->sg.size && !enospc) {
        psock->cork_bytes = msg->cork_bytes - msg->sg.size;
        if (!psock->cork) {
            psock->cork = kzalloc(sizeof(*psock->cork),
                          GFP_ATOMIC | __GFP_NOWARN);
            if (!psock->cork)
                return -ENOMEM;
        }
        memcpy(psock->cork, msg, sizeof(*msg));
        return 0;
    }

    tosend = msg->sg.size;
    if (psock->apply_bytes && psock->apply_bytes < tosend)
        tosend = psock->apply_bytes;

    switch (psock->eval) {
    case __SK_PASS:
        ret = tcp_bpf_push(sk, msg, tosend, flags, true);
        if (unlikely(ret)) {
            *copied -= sk_msg_free(sk, msg);
            break;
        }
        sk_msg_apply_bytes(psock, tosend);
        break;
    case __SK_REDIRECT:
        sk_redir = psock->sk_redir;
        sk_msg_apply_bytes(psock, tosend);
        if (!psock->apply_bytes) {
            /* Clean up before releasing the sock lock. */
            eval = psock->eval;
            psock->eval = __SK_NONE;
            psock->sk_redir = NULL;
        }
        if (psock->cork) {
            cork = true;
            psock->cork = NULL;
        }
        sk_msg_return(sk, msg, msg->sg.size);
        release_sock(sk);

        ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);

        if (eval == __SK_REDIRECT)
            sock_put(sk_redir);

        lock_sock(sk);
        if (unlikely(ret < 0)) {
            int free = sk_msg_free_nocharge(sk, msg);

            if (!cork)
                *copied -= free;
        }
        if (cork) {
            sk_msg_free(sk, msg);
            kfree(msg);
            msg = NULL;
            ret = 0;
        }
        break;
    case __SK_DROP:
    default:
        sk_msg_free_partial(sk, msg, tosend);
        sk_msg_apply_bytes(psock, tosend);
        *copied -= (tosend + delta);
        return -EACCES;
    }

    if (likely(!ret)) {
        if (!psock->apply_bytes) {
            psock->eval = __SK_NONE;
            if (psock->sk_redir) {
                sock_put(psock->sk_redir);
                psock->sk_redir = NULL;
            }
        }
        if (msg &&
            msg->sg.data[msg->sg.start].page_link &&
            msg->sg.data[msg->sg.start].length) {
            if (eval == __SK_REDIRECT)
                sk_mem_charge(sk, msg->sg.size);
            goto more_data;
        }
    }
    return ret;
}

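/* sendmsg handler for the TX configurations. User data is copied into an
 * sk_msg (or appended to a pending cork buffer), cork bytes are accounted,
 * and tcp_bpf_send_verdict() decides whether the message is passed,
 * redirected or dropped.
 */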
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
    struct sk_msg tmp, *msg_tx = NULL;
    int copied = 0, err = 0;
    struct sk_psock *psock;
    long timeo;
    int flags;

    /* Don't let internal do_tcp_sendpages() flags through */
    flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
    flags |= MSG_NO_SHARED_FRAGS;

    psock = sk_psock_get(sk);
    if (unlikely(!psock))
        return tcp_sendmsg(sk, msg, size);

    lock_sock(sk);
    timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
    while (msg_data_left(msg)) {
        bool enospc = false;
        u32 copy, osize;

        if (sk->sk_err) {
            err = -sk->sk_err;
            goto out_err;
        }

        copy = msg_data_left(msg);
        if (!sk_stream_memory_free(sk))
            goto wait_for_sndbuf;
        if (psock->cork) {
            msg_tx = psock->cork;
        } else {
            msg_tx = &tmp;
            sk_msg_init(msg_tx);
        }

        osize = msg_tx->sg.size;
        err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
        if (err) {
            if (err != -ENOSPC)
                goto wait_for_memory;
            enospc = true;
            copy = msg_tx->sg.size - osize;
        }

        err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                           copy);
        if (err < 0) {
            sk_msg_trim(sk, msg_tx, osize);
            goto out_err;
        }

        copied += copy;
        if (psock->cork_bytes) {
            if (size > psock->cork_bytes)
                psock->cork_bytes = 0;
            else
                psock->cork_bytes -= size;
            if (psock->cork_bytes && !enospc)
                goto out_err;
            /* All cork bytes are accounted, rerun the prog. */
            psock->eval = __SK_NONE;
            psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
        if (unlikely(err < 0))
            goto out_err;
        continue;
wait_for_sndbuf:
        set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
        err = sk_stream_wait_memory(sk, &timeo);
        if (err) {
            if (msg_tx && msg_tx != psock->cork)
                sk_msg_free(sk, msg_tx);
            goto out_err;
        }
    }
out_err:
    if (err < 0)
        err = sk_stream_error(sk, msg->msg_flags, err);
    release_sock(sk);
    sk_psock_put(sk, psock);
    return copied ? copied : err;
}

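/* sendpage handler: add a reference to @page to an sk_msg (or to the cork
 * buffer) without copying the data, then run the same verdict logic as
 * tcp_bpf_sendmsg().
 */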
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                size_t size, int flags)
{
    struct sk_msg tmp, *msg = NULL;
    int err = 0, copied = 0;
    struct sk_psock *psock;
    bool enospc = false;

    psock = sk_psock_get(sk);
    if (unlikely(!psock))
        return tcp_sendpage(sk, page, offset, size, flags);

    lock_sock(sk);
    if (psock->cork) {
        msg = psock->cork;
    } else {
        msg = &tmp;
        sk_msg_init(msg);
    }

    /* Catch case where ring is full and sendpage is stalled. */
    if (unlikely(sk_msg_full(msg)))
        goto out_err;

    sk_msg_page_add(msg, page, size, offset);
    sk_mem_charge(sk, size);
    copied = size;
    if (sk_msg_full(msg))
        enospc = true;
    if (psock->cork_bytes) {
        if (size > psock->cork_bytes)
            psock->cork_bytes = 0;
        else
            psock->cork_bytes -= size;
        if (psock->cork_bytes && !enospc)
            goto out_err;
        /* All cork bytes are accounted, rerun the prog. */
        psock->eval = __SK_NONE;
        psock->cork_bytes = 0;
    }

    err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
    release_sock(sk);
    sk_psock_put(sk, psock);
    return copied ? copied : err;
}

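/* tcp_bpf works by swapping in a modified struct proto on the socket. One
 * proto instance is kept per address family and per configuration, depending
 * on which of the send and receive paths have to be overridden.
 */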
enum {
    TCP_BPF_IPV4,
    TCP_BPF_IPV6,
    TCP_BPF_NUM_PROTS,
};

enum {
    TCP_BPF_BASE,
    TCP_BPF_TX,
    TCP_BPF_RX,
    TCP_BPF_TXRX,
    TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

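/* Build the per-configuration protos from @base: BASE overrides close,
 * destroy, recvmsg and sock_is_readable; TX additionally overrides the
 * transmit hooks; RX and TXRX use the parser-aware recvmsg.
 */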
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                   struct proto *base)
{
    prot[TCP_BPF_BASE]          = *base;
    prot[TCP_BPF_BASE].destroy      = sock_map_destroy;
    prot[TCP_BPF_BASE].close        = sock_map_close;
    prot[TCP_BPF_BASE].recvmsg      = tcp_bpf_recvmsg;
    prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

    prot[TCP_BPF_TX]            = prot[TCP_BPF_BASE];
    prot[TCP_BPF_TX].sendmsg        = tcp_bpf_sendmsg;
    prot[TCP_BPF_TX].sendpage       = tcp_bpf_sendpage;

    prot[TCP_BPF_RX]            = prot[TCP_BPF_BASE];
    prot[TCP_BPF_RX].recvmsg        = tcp_bpf_recvmsg_parser;

    prot[TCP_BPF_TXRX]          = prot[TCP_BPF_TX];
    prot[TCP_BPF_TXRX].recvmsg      = tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
    if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
        spin_lock_bh(&tcpv6_prot_lock);
        if (likely(ops != tcpv6_prot_saved)) {
            tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
            smp_store_release(&tcpv6_prot_saved, ops);
        }
        spin_unlock_bh(&tcpv6_prot_lock);
    }
}

static int __init tcp_bpf_v4_build_proto(void)
{
    tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
    return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
    /* In order to avoid retpoline, we make assumptions when we call
     * into ops if e.g. a psock is not present. Make sure they are
     * indeed valid assumptions.
     */
    return ops->recvmsg  == tcp_recvmsg &&
           ops->sendmsg  == tcp_sendmsg &&
           ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

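/* Switch @sk to (or back from) the tcp_bpf proto matching the programs
 * attached to @psock. With @restore set the saved proto and write_space
 * callbacks are reinstated, taking care of the TLS ULP case where
 * tcp_update_ulp() must be used instead of writing sk_prot directly.
 */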
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
    int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
    int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

    if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
        config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
    }

    if (restore) {
        if (inet_csk_has_ulp(sk)) {
            /* TLS does not have an unhash proto in SW cases,
             * but we need to ensure we stop using the sock_map
             * unhash routine because the associated psock is being
             * removed. So use the original unhash handler.
             */
            WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
            tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
        } else {
            sk->sk_write_space = psock->saved_write_space;
            /* Pairs with lockless read in sk_clone_lock() */
            WRITE_ONCE(sk->sk_prot, psock->sk_proto);
        }
        return 0;
    }

    if (sk->sk_family == AF_INET6) {
        if (tcp_bpf_assert_proto_ops(psock->sk_proto))
            return -EINVAL;

        tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
    }

    /* Pairs with lockless read in sk_clone_lock() */
    WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
    return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
    int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
    struct proto *prot = newsk->sk_prot;

    if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
        newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */