// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/ipv6_stubs.h>
#endif

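/* Deliver a decapsulated non-ESP (IKE) message to userspace: charge it to
 * the socket's receive buffer and queue it on ike_queue for recvmsg().
 */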
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}

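/* Hand a decapsulated ESP packet to the xfrm input path via the
 * TCP_ENCAP_ESPINTCP encap receive hooks.
 */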
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;

	skb_reset_transport_header(skb);

	/* restore the IP control block saved by TCP; the xfrm input path
	 * needs it (at least IP6CB->nhoff)
	 */
	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}

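/* strparser rcv_msg callback: each record starts with a 2-byte length,
 * followed by either a 4-byte non-ESP marker of zero (IKE) or the ESP SPI.
 * Single-byte 0xff payloads are keepalives and are simply dropped.
 */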
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	int len = rxm->full_len - 2;
	u32 nonesp_marker;
	int err;

	/* keepalive packet? */
	if (unlikely(len == 1)) {
		u8 data;

		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
		if (err < 0) {
			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
			kfree_skb(skb);
			return;
		}

		if (data == 0xff) {
			kfree_skb(skb);
			return;
		}
	}

	/* drop other short messages */
	if (unlikely(len <= sizeof(nonesp_marker))) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* remove the length header, keep the non-ESP marker/SPI */
	if (!__pskb_pull(skb, rxm->offset + 2)) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}

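/* strparser parse_msg callback: read the 2-byte length prefix and return the
 * total record length, or 0 if not enough data has arrived yet.
 */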
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 2)
		return -EINVAL;

	return len;
}

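/* recvmsg for the userspace (IKE) side of the socket: dequeue one message
 * from ike_queue and copy it out, truncating if the buffer is too small.
 */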
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb) {
		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
			return 0;
		return err;
	}

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	return copied;
}

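/* Queue an skb for later transmission; espintcp_release() (this socket's
 * release_cb) drains out_queue and pushes each pending skb.
 */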
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >= READ_ONCE(netdev_max_backlog))
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* the espintcp length field is 2 bytes and includes its own size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

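/* Transmit the scatterlist of a partially built message page by page with
 * do_tcp_sendpages(), recording progress in emsg->offset and skmsg->sg.start
 * so an interrupted send can be resumed by a later call.
 */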
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

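/* Flush the pending partial message, if any.  tx_running guards against
 * reentry between the tx work and the send paths.
 */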
static int espintcp_push_msgs(struct sock *sk, int flags)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, flags);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, flags);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}

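/* Send a fully built ESP skb on the TCP socket.  Any previously pending
 * partial message is flushed first; if it cannot be cleared, the new skb is
 * dropped with -ENOBUFS.
 */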
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk, 0);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk, 0);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);

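/* sendmsg for the userspace (IKE) side: build an sk_msg containing the
 * 2-byte big-endian length prefix followed by the payload, then push it.
 */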
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	if (err < 0) {
		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
			err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* reserve room for the 2-byte length prefix plus the payload */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	/* the message may have been sent only partially; it stays queued in
	 * ctx->partial and will be retried later
	 */

	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}

static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);

static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	strp_data_ready(&ctx->strp);
}

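/* Work item scheduled from espintcp_write_space(): retry sending the pending
 * partial message once the socket has room again.
 */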
static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk, 0);
	release_sock(sk);
}

static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}

static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	ctx->saved_destruct(sk);
	kfree(ctx);
}

bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);

static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);

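/* ULP init: attach the espintcp context to the TCP socket, start the stream
 * parser, and replace the proto/proto_ops and socket callbacks.
 */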
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sk_user_data is already taken, e.g. by sockmap or another ULP */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag: transmissions may run outside process context */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}

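/* release_cb: send any skbs that were queued with espintcp_queue_out(),
 * then run TCP's own deferred release work.
 */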
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}

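/* Tear down the espintcp state before closing the TCP socket: stop the
 * stream parser, restore tcp_prot, and free anything still queued.
 */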
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}

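/* Also report readability when non-ESP (IKE) messages are waiting in
 * ike_queue, which datagram_poll() does not know about.
 */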
static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}

static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops)
{
	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
	espintcp_prot->sendmsg = espintcp_sendmsg;
	espintcp_prot->recvmsg = espintcp_recvmsg;
	espintcp_prot->close = espintcp_close;
	espintcp_prot->release_cb = espintcp_release;
	espintcp_ops->poll = espintcp_poll;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};

void __init espintcp_init(void)
{
	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);

	tcp_register_ulp(&espintcp_ulp);
}