Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 #include <net/tcp.h>
0003 #include <net/strparser.h>
0004 #include <net/xfrm.h>
0005 #include <net/esp.h>
0006 #include <net/espintcp.h>
0007 #include <linux/skmsg.h>
0008 #include <net/inet_common.h>
0009 #if IS_ENABLED(CONFIG_IPV6)
0010 #include <net/ipv6_stubs.h>
0011 #endif
0012 
0013 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
0014               struct sock *sk)
0015 {
0016     if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
0017         !sk_rmem_schedule(sk, skb, skb->truesize)) {
0018         XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
0019         kfree_skb(skb);
0020         return;
0021     }
0022 
0023     skb_set_owner_r(skb, sk);
0024 
0025     memset(skb->cb, 0, sizeof(skb->cb));
0026     skb_queue_tail(&ctx->ike_queue, skb);
0027     ctx->saved_data_ready(sk);
0028 }
0029 
/*
 * Hand an ESP packet that arrived over the TCP stream to the xfrm input
 * path as TCP-encapsulated ESP (TCP_ENCAP_ESPINTCP).
 *
 * espintcp_rcv() has already pulled the 2-byte length prefix, so skb->data
 * is expected to point at the SPI when this is called.
 */
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;

	/* The current data pointer (the ESP header) becomes the transport header. */
	skb_reset_transport_header(skb);

	/* restore IP CB, we need at least IP6CB->nhoff */
	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));

	/* dev_get_by_index_rcu() is called under rcu_read_lock(), which is held
	 * until the xfrm receive call below has returned.
	 * NOTE(review): skb->dev may end up NULL if skb_iif no longer names a
	 * device — confirm the xfrm input path copes with that.
	 */
	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	/* NOTE(review): BHs are presumably disabled because the xfrm receive
	 * functions expect softirq-like context, while this may run from
	 * process context via the strparser — confirm.
	 */
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	/* IPv6 input is reached through ipv6_stub. */
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}
0051 
0052 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
0053 {
0054     struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
0055                         strp);
0056     struct strp_msg *rxm = strp_msg(skb);
0057     int len = rxm->full_len - 2;
0058     u32 nonesp_marker;
0059     int err;
0060 
0061     /* keepalive packet? */
0062     if (unlikely(len == 1)) {
0063         u8 data;
0064 
0065         err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
0066         if (err < 0) {
0067             XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
0068             kfree_skb(skb);
0069             return;
0070         }
0071 
0072         if (data == 0xff) {
0073             kfree_skb(skb);
0074             return;
0075         }
0076     }
0077 
0078     /* drop other short messages */
0079     if (unlikely(len <= sizeof(nonesp_marker))) {
0080         XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
0081         kfree_skb(skb);
0082         return;
0083     }
0084 
0085     err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
0086                 sizeof(nonesp_marker));
0087     if (err < 0) {
0088         XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
0089         kfree_skb(skb);
0090         return;
0091     }
0092 
0093     /* remove header, leave non-ESP marker/SPI */
0094     if (!__pskb_pull(skb, rxm->offset + 2)) {
0095         XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
0096         kfree_skb(skb);
0097         return;
0098     }
0099 
0100     if (pskb_trim(skb, rxm->full_len - 2) != 0) {
0101         XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
0102         kfree_skb(skb);
0103         return;
0104     }
0105 
0106     if (nonesp_marker == 0)
0107         handle_nonesp(ctx, skb, strp->sk);
0108     else
0109         handle_esp(skb, strp->sk);
0110 }
0111 
0112 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
0113 {
0114     struct strp_msg *rxm = strp_msg(skb);
0115     __be16 blen;
0116     u16 len;
0117     int err;
0118 
0119     if (skb->len < rxm->offset + 2)
0120         return 0;
0121 
0122     err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
0123     if (err < 0)
0124         return err;
0125 
0126     len = be16_to_cpu(blen);
0127     if (len < 2)
0128         return -EINVAL;
0129 
0130     return len;
0131 }
0132 
0133 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
0134                 int flags, int *addr_len)
0135 {
0136     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0137     struct sk_buff *skb;
0138     int err = 0;
0139     int copied;
0140     int off = 0;
0141 
0142     skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
0143     if (!skb) {
0144         if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
0145             return 0;
0146         return err;
0147     }
0148 
0149     copied = len;
0150     if (copied > skb->len)
0151         copied = skb->len;
0152     else if (copied < skb->len)
0153         msg->msg_flags |= MSG_TRUNC;
0154 
0155     err = skb_copy_datagram_msg(skb, 0, msg, copied);
0156     if (unlikely(err)) {
0157         kfree_skb(skb);
0158         return err;
0159     }
0160 
0161     if (flags & MSG_TRUNC)
0162         copied = skb->len;
0163     kfree_skb(skb);
0164     return copied;
0165 }
0166 
0167 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
0168 {
0169     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0170 
0171     if (skb_queue_len(&ctx->out_queue) >= READ_ONCE(netdev_max_backlog))
0172         return -ENOBUFS;
0173 
0174     __skb_queue_tail(&ctx->out_queue, skb);
0175 
0176     return 0;
0177 }
0178 EXPORT_SYMBOL_GPL(espintcp_queue_out);
0179 
/*
 * Largest payload espintcp_sendmsg() accepts. The on-wire length field is
 * 16 bits wide and counts its own 2 bytes, so at most 0xffff - 2 bytes of
 * payload fit.
 */
#define MAX_ESPINTCP_MSG (0xffff - 2)
0182 
0183 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
0184                    int flags)
0185 {
0186     do {
0187         int ret;
0188 
0189         ret = skb_send_sock_locked(sk, emsg->skb,
0190                        emsg->offset, emsg->len);
0191         if (ret < 0)
0192             return ret;
0193 
0194         emsg->len -= ret;
0195         emsg->offset += ret;
0196     } while (emsg->len > 0);
0197 
0198     kfree_skb(emsg->skb);
0199     memset(emsg, 0, sizeof(*emsg));
0200 
0201     return 0;
0202 }
0203 
/*
 * Transmit the sk_msg held in @emsg (built by espintcp_sendmsg()) with
 * do_tcp_sendpages(), one scatterlist entry at a time. Sending starts at
 * entry skmsg->sg.start, @emsg->offset bytes into it. @flags is passed
 * through to do_tcp_sendpages().
 *
 * Every entry but the last is sent with MSG_SENDPAGE_NOTLAST. Short writes
 * are retried until the entry is complete. For each completed entry, the
 * page reference is dropped and sg->length is uncharged from @sk.
 *
 * On a send error, the progress is recorded so a later call resumes where
 * this one stopped, and the negative error is returned:
 *   - skmsg->sg.start is advanced past the completed entries;
 *   - @emsg->offset holds the bytes already sent from the current entry.
 * On success @emsg is cleared and 0 is returned.
 */
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;	/* entries fully sent during this call */
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		/* Only the first entry can start part-way through. */
		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			/* Save the resume point for the next call. */
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		/* Partial write: send the rest of this entry. */
		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
0249 
0250 static int espintcp_push_msgs(struct sock *sk, int flags)
0251 {
0252     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0253     struct espintcp_msg *emsg = &ctx->partial;
0254     int err;
0255 
0256     if (!emsg->len)
0257         return 0;
0258 
0259     if (ctx->tx_running)
0260         return -EAGAIN;
0261     ctx->tx_running = 1;
0262 
0263     if (emsg->skb)
0264         err = espintcp_sendskb_locked(sk, emsg, flags);
0265     else
0266         err = espintcp_sendskmsg_locked(sk, emsg, flags);
0267     if (err == -EAGAIN) {
0268         ctx->tx_running = 0;
0269         return flags & MSG_DONTWAIT ? -EAGAIN : 0;
0270     }
0271     if (!err)
0272         memset(emsg, 0, sizeof(*emsg));
0273 
0274     ctx->tx_running = 0;
0275 
0276     return err;
0277 }
0278 
0279 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
0280 {
0281     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0282     struct espintcp_msg *emsg = &ctx->partial;
0283     unsigned int len;
0284     int offset;
0285 
0286     if (sk->sk_state != TCP_ESTABLISHED) {
0287         kfree_skb(skb);
0288         return -ECONNRESET;
0289     }
0290 
0291     offset = skb_transport_offset(skb);
0292     len = skb->len - offset;
0293 
0294     espintcp_push_msgs(sk, 0);
0295 
0296     if (emsg->len) {
0297         kfree_skb(skb);
0298         return -ENOBUFS;
0299     }
0300 
0301     skb_set_owner_w(skb, sk);
0302 
0303     emsg->offset = offset;
0304     emsg->len = len;
0305     emsg->skb = skb;
0306 
0307     espintcp_push_msgs(sk, 0);
0308 
0309     return 0;
0310 }
0311 EXPORT_SYMBOL_GPL(espintcp_push_skb);
0312 
0313 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
0314 {
0315     long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
0316     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0317     struct espintcp_msg *emsg = &ctx->partial;
0318     struct iov_iter pfx_iter;
0319     struct kvec pfx_iov = {};
0320     size_t msglen = size + 2;
0321     char buf[2] = {0};
0322     int err, end;
0323 
0324     if (msg->msg_flags & ~MSG_DONTWAIT)
0325         return -EOPNOTSUPP;
0326 
0327     if (size > MAX_ESPINTCP_MSG)
0328         return -EMSGSIZE;
0329 
0330     if (msg->msg_controllen)
0331         return -EOPNOTSUPP;
0332 
0333     lock_sock(sk);
0334 
0335     err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
0336     if (err < 0) {
0337         if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
0338             err = -ENOBUFS;
0339         goto unlock;
0340     }
0341 
0342     sk_msg_init(&emsg->skmsg);
0343     while (1) {
0344         /* only -ENOMEM is possible since we don't coalesce */
0345         err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
0346         if (!err)
0347             break;
0348 
0349         err = sk_stream_wait_memory(sk, &timeo);
0350         if (err)
0351             goto fail;
0352     }
0353 
0354     *((__be16 *)buf) = cpu_to_be16(msglen);
0355     pfx_iov.iov_base = buf;
0356     pfx_iov.iov_len = sizeof(buf);
0357     iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
0358 
0359     err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
0360                        pfx_iov.iov_len);
0361     if (err < 0)
0362         goto fail;
0363 
0364     err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
0365     if (err < 0)
0366         goto fail;
0367 
0368     end = emsg->skmsg.sg.end;
0369     emsg->len = size;
0370     sk_msg_iter_var_prev(end);
0371     sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
0372 
0373     tcp_rate_check_app_limited(sk);
0374 
0375     err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
0376     /* this message could be partially sent, keep it */
0377 
0378     release_sock(sk);
0379 
0380     return size;
0381 
0382 fail:
0383     sk_msg_free(sk, &emsg->skmsg);
0384     memset(emsg, 0, sizeof(*emsg));
0385 unlock:
0386     release_sock(sk);
0387     return err;
0388 }
0389 
/* IPv4 tables: filled once by espintcp_init() and read-only after init. */
static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
/*
 * IPv6 tables: built on first use in espintcp_init_sk() from the socket's
 * own proto/ops, i.e. after init, so they cannot be __ro_after_init.
 * tcpv6_prot_mutex serialises that one-time construction.
 */
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);
0395 
0396 static void espintcp_data_ready(struct sock *sk)
0397 {
0398     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0399 
0400     strp_data_ready(&ctx->strp);
0401 }
0402 
0403 static void espintcp_tx_work(struct work_struct *work)
0404 {
0405     struct espintcp_ctx *ctx = container_of(work,
0406                         struct espintcp_ctx, work);
0407     struct sock *sk = ctx->strp.sk;
0408 
0409     lock_sock(sk);
0410     if (!ctx->tx_running)
0411         espintcp_push_msgs(sk, 0);
0412     release_sock(sk);
0413 }
0414 
0415 static void espintcp_write_space(struct sock *sk)
0416 {
0417     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0418 
0419     schedule_work(&ctx->work);
0420     ctx->saved_write_space(sk);
0421 }
0422 
0423 static void espintcp_destruct(struct sock *sk)
0424 {
0425     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0426 
0427     ctx->saved_destruct(sk);
0428     kfree(ctx);
0429 }
0430 
0431 bool tcp_is_ulp_esp(struct sock *sk)
0432 {
0433     return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
0434 }
0435 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
0436 
/*
 * Defined further down; espintcp_init_sk() needs it to build the IPv6
 * proto/ops tables lazily.
 */
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);
/*
 * ->init hook of the "espintcp" TCP ULP: attach espintcp to @sk.
 *
 * Steps:
 *   - allocate the espintcp context and start a strparser on the socket;
 *   - switch the socket to the espintcp proto/proto_ops, building the IPv6
 *     tables on first use;
 *   - install the data_ready/write_space/destruct callbacks, saving the
 *     previous ones in the context.
 *
 * Returns 0 on success, -EBUSY if sk_user_data is already set, -ENOMEM if
 * the context cannot be allocated, or the error from strp_init().
 */
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	/* NOTE(review): presumably drops the cached route so the next transmit
	 * performs a fresh lookup — confirm the reason.
	 */
	__sk_dst_reset(sk);

	/* NOTE(review): presumably schedules processing of data already queued
	 * on the socket. ike_queue and the saved callbacks are only set up
	 * below, so this relies on that processing not running before init
	 * returns (e.g. because the socket lock is held) — confirm.
	 */
	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		/* The IPv6 tables are cloned from this socket's own proto/ops
		 * the first time an IPv6 socket attaches. The mutex makes the
		 * check-and-build race-free.
		 */
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	/* NOTE(review): sk_prot/ops and the callbacks are switched before
	 * icsk_ulp_data is published and ctx->work is initialised, yet the
	 * espintcp hooks reach the context via espintcp_getctx(). This looks
	 * safe only if none of those hooks can run before init returns —
	 * confirm.
	 */
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}
0499 
0500 static void espintcp_release(struct sock *sk)
0501 {
0502     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0503     struct sk_buff_head queue;
0504     struct sk_buff *skb;
0505 
0506     __skb_queue_head_init(&queue);
0507     skb_queue_splice_init(&ctx->out_queue, &queue);
0508 
0509     while ((skb = __skb_dequeue(&queue)))
0510         espintcp_push_skb(sk, skb);
0511 
0512     tcp_release_cb(sk);
0513 }
0514 
/*
 * ->close hook: tear espintcp down and close the underlying TCP socket.
 *
 * In order:
 *   - stop the strparser and switch sk_prot back to tcp_prot;
 *   - wait for the tx work and finish the strparser;
 *   - free everything still queued: out_queue, ike_queue and any partially
 *     sent message in ctx->partial;
 *   - finish with tcp_close().
 */
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	/* NOTE(review): IPv6 sockets ran on espintcp6_prot, which was cloned
	 * from their IPv6 proto, but are reset to the IPv4 tcp_prot here as
	 * well. Confirm that the protocol hooks used during and after
	 * tcp_close() are correct for AF_INET6 with tcp_prot.
	 */
	sk->sk_prot = &tcp_prot;
	/* Compiler barrier: keep the sk_prot store ahead of the teardown below. */
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	/* Discard a message that was never (fully) transmitted. */
	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}
0540 
0541 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
0542                   poll_table *wait)
0543 {
0544     __poll_t mask = datagram_poll(file, sock, wait);
0545     struct sock *sk = sock->sk;
0546     struct espintcp_ctx *ctx = espintcp_getctx(sk);
0547 
0548     if (!skb_queue_empty(&ctx->ike_queue))
0549         mask |= EPOLLIN | EPOLLRDNORM;
0550 
0551     return mask;
0552 }
0553 
0554 static void build_protos(struct proto *espintcp_prot,
0555              struct proto_ops *espintcp_ops,
0556              const struct proto *orig_prot,
0557              const struct proto_ops *orig_ops)
0558 {
0559     memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
0560     memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
0561     espintcp_prot->sendmsg = espintcp_sendmsg;
0562     espintcp_prot->recvmsg = espintcp_recvmsg;
0563     espintcp_prot->close = espintcp_close;
0564     espintcp_prot->release_cb = espintcp_release;
0565     espintcp_ops->poll = espintcp_poll;
0566 }
0567 
0568 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
0569     .name = "espintcp",
0570     .owner = THIS_MODULE,
0571     .init = espintcp_init_sk,
0572 };
0573 
0574 void __init espintcp_init(void)
0575 {
0576     build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
0577 
0578     tcp_register_ulp(&espintcp_ulp);
0579 }