0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034 #include <linux/module.h>
0035
0036 #include <net/tcp.h>
0037 #include <net/inet_common.h>
0038 #include <linux/highmem.h>
0039 #include <linux/netdevice.h>
0040 #include <linux/sched/signal.h>
0041 #include <linux/inetdevice.h>
0042 #include <linux/inet_diag.h>
0043
0044 #include <net/snmp.h>
0045 #include <net/tls.h>
0046 #include <net/tls_toe.h>
0047
0048 #include "tls.h"
0049
0050 MODULE_AUTHOR("Mellanox Technologies");
0051 MODULE_DESCRIPTION("Transport Layer Security Support");
0052 MODULE_LICENSE("Dual BSD/GPL");
0053 MODULE_ALIAS_TCP_ULP("tls");
0054
/* Index into the per-address-family protocol tables below. */
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

/* The base TCP proto for each family is saved so tls_build_proto() can
 * detect when it changes and rebuild the derived tables; each is guarded
 * by its own mutex.
 */
static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
/* Derived proto / proto_ops variants, indexed [family][tx_conf][rx_conf]. */
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);
0069
0070 void update_sk_prot(struct sock *sk, struct tls_context *ctx)
0071 {
0072 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
0073
0074 WRITE_ONCE(sk->sk_prot,
0075 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
0076 WRITE_ONCE(sk->sk_socket->ops,
0077 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
0078 }
0079
/* Block until sk->sk_write_pending drops to zero, the sndtimeo in *timeo
 * expires (-EAGAIN), or a signal arrives (sock_intr_errno()).  *timeo is
 * decremented by the time slept.  Returns 0 once no writer is pending.
 */
int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}
0103
/* tls_push_sg - transmit an encrypted record's scatterlist over TCP
 * @sk: socket to send on
 * @ctx: TLS context of @sk
 * @sg: scatterlist holding the record; walked until sg_is_last()
 * @first_offset: byte offset into the first entry to start from
 * @flags: MSG_* flags passed through to do_tcp_sendpages()
 *
 * On a hard send error the unsent tail of the scatterlist is parked in
 * ctx->partially_sent_record/_offset so tls_push_partial_record() can
 * resume later.  Returns 0 on full transmission, else the negative error
 * from do_tcp_sendpages().
 */
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	/* Every chunk but the last carries MSG_SENDPAGE_NOTLAST so TCP may
	 * coalesce instead of pushing each page immediately.
	 */
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	/* Flag re-entry so tls_write_space() only chains the callback. */
	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;


		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			if (ret > 0) {
				/* Partial progress: retry the remainder. */
				offset += ret;
				size -= ret;
				goto retry;
			}

			/* Error: record where we stopped (offset relative
			 * to the sg entry, not the page) for a later retry.
			 */
			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		/* Entry fully handed to TCP, which took its own page ref. */
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}
0158
0159 static int tls_handle_open_record(struct sock *sk, int flags)
0160 {
0161 struct tls_context *ctx = tls_get_ctx(sk);
0162
0163 if (tls_is_pending_open_record(ctx))
0164 return ctx->push_pending_record(sk, flags);
0165
0166 return 0;
0167 }
0168
/* Parse SOL_TLS control messages on a sendmsg call.
 * Currently only TLS_SET_RECORD_TYPE is understood: it flushes any open
 * record (a record may not change type mid-stream) and stores the new
 * type in *record_type.  Returns 0 on success, -EINVAL on malformed or
 * unknown cmsgs (including TLS_SET_RECORD_TYPE combined with MSG_MORE),
 * or the error from flushing the open record.
 */
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			/* record type is per-record; can't span MSG_MORE */
			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}
0203
0204 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
0205 int flags)
0206 {
0207 struct scatterlist *sg;
0208 u16 offset;
0209
0210 sg = ctx->partially_sent_record;
0211 offset = ctx->partially_sent_offset;
0212
0213 ctx->partially_sent_record = NULL;
0214 return tls_push_sg(sk, ctx, sg, offset, flags);
0215 }
0216
0217 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
0218 {
0219 struct scatterlist *sg;
0220
0221 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
0222 put_page(sg_page(sg));
0223 sk_mem_uncharge(sk, sg->length);
0224 }
0225 ctx->partially_sent_record = NULL;
0226 }
0227
/* sk_write_space replacement installed for TLS TX sockets. */
static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If we are inside do_tcp_sendpages() (tls_push_sg() set
	 * in_tcp_sendpages), only chain to the lower protocol's handler so
	 * anything waiting there wakes up; don't re-enter TLS TX state.
	 */
	if (ctx->in_tcp_sendpages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	/* Always chain to the callback saved when TLS was installed. */
	ctx->sk_write_space(sk);
}
0250
0251
0252
0253
0254
0255
0256
0257
0258
/* Free a TLS context, scrubbing key material first.
 * If @sk is non-NULL, lockless readers (e.g. tls_get_info() under RCU)
 * may still reference the context, so the free is deferred via
 * kfree_rcu(); with no socket a plain kfree() is safe.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	/* memzero_explicit() can't be elided: keys must not linger. */
	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}
0273
/* Release TX/RX crypto resources and decrement the "current sessions"
 * MIB counters.  Called from tls_sk_proto_close() with the socket locked.
 */
static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	/* Give a pending writer a chance to finish, then flush any record
	 * still open so nothing is lost on teardown.
	 */
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}
0300
/* .close handler for TLS sockets: unwind the ULP, restore the original
 * proto/callbacks, free TLS state, then chain to the saved close().
 */
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	/* Must run before taking the socket lock: the TX work itself
	 * takes the lock, so cancelling under it could deadlock.
	 */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	/* HW offload paths free the ctx themselves later. */
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	/* Swap sk_prot and sk_write_space back atomically w.r.t. softirq
	 * users of the callbacks.
	 */
	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	/* Strparser/ctx teardown must happen outside the socket lock. */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}
0336
/* TLS_TX / TLS_RX getsockopt: copy the crypto parameters for one
 * direction back to userspace.
 *
 * With len == sizeof(struct tls_crypto_info) only the version/cipher
 * header is returned; otherwise len must exactly match the cipher's full
 * tls12_crypto_info_* struct, whose IV and record sequence are refreshed
 * from the live cipher context under the socket lock.
 * Returns 0 or a negative errno (-EINVAL, -EBUSY, -EFAULT).
 */
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	/* -EBUSY until setsockopt has installed keys for this direction */
	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	/* header-only query: version + cipher_type */
	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	/* Each case below: validate exact length, snapshot IV/rec_seq from
	 * the cipher context under lock_sock(), then copy out.
	 */
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
			container_of(crypto_info,
				struct tls12_crypto_info_aes_ccm_128, info);

		if (len != sizeof(*aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
			container_of(crypto_info,
				struct tls12_crypto_info_chacha20_poly1305,
				info);

		if (len != sizeof(*chacha20_poly1305)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(chacha20_poly1305->iv,
		       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
		       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
		memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, chacha20_poly1305,
				sizeof(*chacha20_poly1305)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_gcm, info);

		if (len != sizeof(*sm4_gcm_info)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(sm4_gcm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
		       TLS_CIPHER_SM4_GCM_IV_SIZE);
		memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_ccm, info);

		if (len != sizeof(*sm4_ccm_info)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(sm4_ccm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
		       TLS_CIPHER_SM4_CCM_IV_SIZE);
		memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}
0517
0518 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
0519 int __user *optlen)
0520 {
0521 struct tls_context *ctx = tls_get_ctx(sk);
0522 unsigned int value;
0523 int len;
0524
0525 if (get_user(len, optlen))
0526 return -EFAULT;
0527
0528 if (len != sizeof(value))
0529 return -EINVAL;
0530
0531 value = ctx->zerocopy_sendfile;
0532 if (copy_to_user(optval, &value, sizeof(value)))
0533 return -EFAULT;
0534
0535 return 0;
0536 }
0537
/* TLS_RX_EXPECT_NO_PAD getsockopt: report the rx_no_pad flag.
 * Only meaningful for TLS 1.3 and only once RX is configured (SW or HW);
 * otherwise -EINVAL.  Also writes the value size back through optlen.
 */
static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < sizeof(value))
		return -EINVAL;

	/* rx_conf may change under us, so sample it under the lock */
	lock_sock(sk);
	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	release_sock(sk);
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}
0567
0568 static int do_tls_getsockopt(struct sock *sk, int optname,
0569 char __user *optval, int __user *optlen)
0570 {
0571 int rc = 0;
0572
0573 switch (optname) {
0574 case TLS_TX:
0575 case TLS_RX:
0576 rc = do_tls_getsockopt_conf(sk, optval, optlen,
0577 optname == TLS_TX);
0578 break;
0579 case TLS_TX_ZEROCOPY_RO:
0580 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
0581 break;
0582 case TLS_RX_EXPECT_NO_PAD:
0583 rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
0584 break;
0585 default:
0586 rc = -ENOPROTOOPT;
0587 break;
0588 }
0589 return rc;
0590 }
0591
0592 static int tls_getsockopt(struct sock *sk, int level, int optname,
0593 char __user *optval, int __user *optlen)
0594 {
0595 struct tls_context *ctx = tls_get_ctx(sk);
0596
0597 if (level != SOL_TLS)
0598 return ctx->sk_proto->getsockopt(sk, level,
0599 optname, optval, optlen);
0600
0601 return do_tls_getsockopt(sk, optname, optval, optlen);
0602 }
0603
/* TLS_TX / TLS_RX setsockopt: install crypto parameters for one direction
 * and switch the socket onto the matching TLS proto.  Called with the
 * socket locked.  Tries device (HW) offload first and falls back to the
 * SW path.  On any failure the copied-in key material is scrubbed from
 * the context before returning.
 */
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	/* copy just the header first so version/cipher can be validated */
	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256: {
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	}
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	case TLS_CIPHER_SM4_GCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
		break;
	case TLS_CIPHER_SM4_CCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* now pull in the cipher-specific tail (keys, IV, salt, rec_seq) */
	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			/* device offload unavailable: software fallback */
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		/* hook write-space so TLS TX is notified of send room */
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		/* data may already be queued; kick the strparser */
		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}
0739
0740 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
0741 unsigned int optlen)
0742 {
0743 struct tls_context *ctx = tls_get_ctx(sk);
0744 unsigned int value;
0745
0746 if (sockptr_is_null(optval) || optlen != sizeof(value))
0747 return -EINVAL;
0748
0749 if (copy_from_sockptr(&value, optval, sizeof(value)))
0750 return -EFAULT;
0751
0752 if (value > 1)
0753 return -EINVAL;
0754
0755 ctx->zerocopy_sendfile = value;
0756
0757 return 0;
0758 }
0759
/* TLS_RX_EXPECT_NO_PAD setsockopt: boolean hint that TLS 1.3 records
 * carry no padding, enabling the zero-copy RX fast path.  Requires
 * TLS 1.3 and a configured RX direction; any bytes beyond sizeof(u32)
 * must be zero.
 */
static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	/* reject oversized buffers unless the tail is all zeroes */
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ? -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}
0791
0792 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
0793 unsigned int optlen)
0794 {
0795 int rc = 0;
0796
0797 switch (optname) {
0798 case TLS_TX:
0799 case TLS_RX:
0800 lock_sock(sk);
0801 rc = do_tls_setsockopt_conf(sk, optval, optlen,
0802 optname == TLS_TX);
0803 release_sock(sk);
0804 break;
0805 case TLS_TX_ZEROCOPY_RO:
0806 lock_sock(sk);
0807 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
0808 release_sock(sk);
0809 break;
0810 case TLS_RX_EXPECT_NO_PAD:
0811 rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
0812 break;
0813 default:
0814 rc = -ENOPROTOOPT;
0815 break;
0816 }
0817 return rc;
0818 }
0819
0820 static int tls_setsockopt(struct sock *sk, int level, int optname,
0821 sockptr_t optval, unsigned int optlen)
0822 {
0823 struct tls_context *ctx = tls_get_ctx(sk);
0824
0825 if (level != SOL_TLS)
0826 return ctx->sk_proto->setsockopt(sk, level, optname, optval,
0827 optlen);
0828
0829 return do_tls_setsockopt(sk, optname, optval, optlen);
0830 }
0831
/* Allocate a TLS context for @sk and publish it via icsk_ulp_data.
 * GFP_ATOMIC because callers may hold sk_callback_lock (see tls_init()).
 * Returns NULL on allocation failure.
 */
struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	/* publish only after init so RCU readers see a complete ctx */
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}
0847
/* Derive the [tx_conf][rx_conf] proto_ops table from the base TCP ops.
 * Each entry starts from a previously built entry and overrides only the
 * ops that differ, so the derivation order below matters.
 */
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].sendpage_locked	= tls_sw_sendpage_locked;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	/* HW TX can't use the locked sendpage path */
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_HW  ][TLS_BASE].sendpage_locked	= NULL;

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
	ops[TLS_HW  ][TLS_SW  ].sendpage_locked	= NULL;

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
	ops[TLS_HW  ][TLS_HW  ].sendpage_locked	= NULL;
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}
0880
/* Lazily (re)build the TLS proto tables for @sk's address family.
 * Double-checked locking: the lockless smp_load_acquire() fast path pairs
 * with the smp_store_release() under the mutex, so the tables are fully
 * built before the saved pointer becomes visible.  A rebuild happens
 * whenever the base proto pointer changes.
 */
static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build context for the new family before first use */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}
0911
/* Derive the [tx_conf][rx_conf] struct proto table from the base TCP
 * proto.  As with build_proto_ops(), entries are built incrementally from
 * earlier ones, so order matters.
 */
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;

	/* HW RX reuses the SW receive path on the host side */
	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}
0955
/* ULP .init hook: attach TLS to an established TCP socket in TLS_BASE
 * mode (crypto comes later via setsockopt).  Returns 0 or -ENOTCONN /
 * -ENOMEM.
 */
static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather then
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context; swap prot under sk_callback_lock so the
	 * change is atomic w.r.t. softirq users of the callbacks
	 */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}
0992
/* ULP .update hook: another layer (e.g. sockmap) replaced sk_prot; record
 * the new underlying proto/write_space so TLS chains to it, or restore
 * them directly if no TLS context exists.
 */
static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}
1010
1011 static u16 tls_user_config(struct tls_context *ctx, bool tx)
1012 {
1013 u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
1014
1015 switch (config) {
1016 case TLS_BASE:
1017 return TLS_CONF_BASE;
1018 case TLS_SW:
1019 return TLS_CONF_SW;
1020 case TLS_HW:
1021 return TLS_CONF_HW;
1022 case TLS_HW_RECORD:
1023 return TLS_CONF_HW_RECORD;
1024 }
1025 return 0;
1026 }
1027
/* ULP .get_info hook: fill an INET_ULP_INFO_TLS nest for inet_diag.
 * The context is read under RCU since the socket may be closing
 * concurrently; a vanished context reports an empty nest (err == 0 still
 * cancels the nest via the shared exit path).
 */
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}
	if (ctx->rx_no_pad) {
		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}
1085
1086 static size_t tls_get_info_size(const struct sock *sk)
1087 {
1088 size_t size = 0;
1089
1090 size += nla_total_size(0) +
1091 nla_total_size(sizeof(u16)) +
1092 nla_total_size(sizeof(u16)) +
1093 nla_total_size(sizeof(u16)) +
1094 nla_total_size(sizeof(u16)) +
1095 nla_total_size(0) +
1096 nla_total_size(0) +
1097 0;
1098
1099 return size;
1100 }
1101
1102 static int __net_init tls_init_net(struct net *net)
1103 {
1104 int err;
1105
1106 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
1107 if (!net->mib.tls_statistics)
1108 return -ENOMEM;
1109
1110 err = tls_proc_init(net);
1111 if (err)
1112 goto err_free_stats;
1113
1114 return 0;
1115 err_free_stats:
1116 free_percpu(net->mib.tls_statistics);
1117 return err;
1118 }
1119
/* Per-netns teardown: remove procfs entries before freeing the MIB
 * counters they read.
 */
static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}
1125
/* Per-network-namespace init/exit for TLS statistics. */
static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};
1130
/* Registration as the "tls" TCP upper-layer protocol (SOL_TCP/TCP_ULP). */
static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};
1139
/* Module init: pernet stats, strparser device memory, device offload
 * infrastructure, then the ULP itself.  Errors unwind in reverse order.
 */
static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	err = tls_strp_dev_init();
	if (err)
		goto err_pernet;

	err = tls_device_init();
	if (err)
		goto err_strp;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
err_strp:
	tls_strp_dev_exit();
err_pernet:
	unregister_pernet_subsys(&tls_proc_ops);
	return err;
}
1165
/* Module exit: unregister the ULP first so no new TLS sockets appear,
 * then tear down the supporting infrastructure in reverse of init order.
 */
static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_strp_dev_exit();
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}
1173
1174 module_init(tls_register);
1175 module_exit(tls_unregister);