Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
0003  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
0004  *
0005  * This software is available to you under a choice of one of two
0006  * licenses.  You may choose to be licensed under the terms of the GNU
0007  * General Public License (GPL) Version 2, available from the file
0008  * COPYING in the main directory of this source tree, or the
0009  * OpenIB.org BSD license below:
0010  *
0011  *     Redistribution and use in source and binary forms, with or
0012  *     without modification, are permitted provided that the following
0013  *     conditions are met:
0014  *
0015  *      - Redistributions of source code must retain the above
0016  *        copyright notice, this list of conditions and the following
0017  *        disclaimer.
0018  *
0019  *      - Redistributions in binary form must reproduce the above
0020  *        copyright notice, this list of conditions and the following
0021  *        disclaimer in the documentation and/or other materials
0022  *        provided with the distribution.
0023  *
0024  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
0025  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
0026  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
0027  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
0028  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
0029  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
0030  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
0031  * SOFTWARE.
0032  */
0033 
0034 #include <linux/module.h>
0035 
0036 #include <net/tcp.h>
0037 #include <net/inet_common.h>
0038 #include <linux/highmem.h>
0039 #include <linux/netdevice.h>
0040 #include <linux/sched/signal.h>
0041 #include <linux/inetdevice.h>
0042 #include <linux/inet_diag.h>
0043 
0044 #include <net/snmp.h>
0045 #include <net/tls.h>
0046 #include <net/tls_toe.h>
0047 
0048 #include "tls.h"
0049 
/* Module metadata; MODULE_ALIAS_TCP_ULP declares the "tls" TCP ULP alias. */
0050 MODULE_AUTHOR("Mellanox Technologies");
0051 MODULE_DESCRIPTION("Transport Layer Security Support");
0052 MODULE_LICENSE("Dual BSD/GPL");
0053 MODULE_ALIAS_TCP_ULP("tls");
0054 
/* First index into tls_prots[] and tls_proto_ops[]; picked from
 * sk->sk_family by update_sk_prot() and tls_build_proto().
 */
0055 enum {
0056     TLSV4,
0057     TLSV6,
0058     TLS_NUM_PROTS,
0059 };
0060 
/* Base proto each per-family table was last built from.  tls_build_proto()
 * compares these against the socket's current proto and rebuilds the tables
 * under the matching mutex when they differ.
 */
0061 static const struct proto *saved_tcpv6_prot;
0062 static DEFINE_MUTEX(tcpv6_prot_mutex);
0063 static const struct proto *saved_tcpv4_prot;
0064 static DEFINE_MUTEX(tcpv4_prot_mutex);
/* Variant tables indexed as [ip version][tx_conf][rx_conf]. */
0065 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
0066 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
/* Forward declaration: build_protos() is defined after its first caller. */
0067 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
0068              const struct proto *base);
0069 
0070 void update_sk_prot(struct sock *sk, struct tls_context *ctx)
0071 {
0072     int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
0073 
0074     WRITE_ONCE(sk->sk_prot,
0075            &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
0076     WRITE_ONCE(sk->sk_socket->ops,
0077            &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
0078 }
0079 
0080 int wait_on_pending_writer(struct sock *sk, long *timeo)
0081 {
0082     int rc = 0;
0083     DEFINE_WAIT_FUNC(wait, woken_wake_function);
0084 
0085     add_wait_queue(sk_sleep(sk), &wait);
0086     while (1) {
0087         if (!*timeo) {
0088             rc = -EAGAIN;
0089             break;
0090         }
0091 
0092         if (signal_pending(current)) {
0093             rc = sock_intr_errno(*timeo);
0094             break;
0095         }
0096 
0097         if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
0098             break;
0099     }
0100     remove_wait_queue(sk_sleep(sk), &wait);
0101     return rc;
0102 }
0103 
0104 int tls_push_sg(struct sock *sk,
0105         struct tls_context *ctx,
0106         struct scatterlist *sg,
0107         u16 first_offset,
0108         int flags)
0109 {
0110     int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
0111     int ret = 0;
0112     struct page *p;
0113     size_t size;
0114     int offset = first_offset;
0115 
0116     size = sg->length - offset;
0117     offset += sg->offset;
0118 
0119     ctx->in_tcp_sendpages = true;
0120     while (1) {
0121         if (sg_is_last(sg))
0122             sendpage_flags = flags;
0123 
0124         /* is sending application-limited? */
0125         tcp_rate_check_app_limited(sk);
0126         p = sg_page(sg);
0127 retry:
0128         ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
0129 
0130         if (ret != size) {
0131             if (ret > 0) {
0132                 offset += ret;
0133                 size -= ret;
0134                 goto retry;
0135             }
0136 
0137             offset -= sg->offset;
0138             ctx->partially_sent_offset = offset;
0139             ctx->partially_sent_record = (void *)sg;
0140             ctx->in_tcp_sendpages = false;
0141             return ret;
0142         }
0143 
0144         put_page(p);
0145         sk_mem_uncharge(sk, sg->length);
0146         sg = sg_next(sg);
0147         if (!sg)
0148             break;
0149 
0150         offset = sg->offset;
0151         size = sg->length;
0152     }
0153 
0154     ctx->in_tcp_sendpages = false;
0155 
0156     return 0;
0157 }
0158 
0159 static int tls_handle_open_record(struct sock *sk, int flags)
0160 {
0161     struct tls_context *ctx = tls_get_ctx(sk);
0162 
0163     if (tls_is_pending_open_record(ctx))
0164         return ctx->push_pending_record(sk, flags);
0165 
0166     return 0;
0167 }
0168 
0169 int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
0170              unsigned char *record_type)
0171 {
0172     struct cmsghdr *cmsg;
0173     int rc = -EINVAL;
0174 
0175     for_each_cmsghdr(cmsg, msg) {
0176         if (!CMSG_OK(msg, cmsg))
0177             return -EINVAL;
0178         if (cmsg->cmsg_level != SOL_TLS)
0179             continue;
0180 
0181         switch (cmsg->cmsg_type) {
0182         case TLS_SET_RECORD_TYPE:
0183             if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
0184                 return -EINVAL;
0185 
0186             if (msg->msg_flags & MSG_MORE)
0187                 return -EINVAL;
0188 
0189             rc = tls_handle_open_record(sk, msg->msg_flags);
0190             if (rc)
0191                 return rc;
0192 
0193             *record_type = *(unsigned char *)CMSG_DATA(cmsg);
0194             rc = 0;
0195             break;
0196         default:
0197             return -EINVAL;
0198         }
0199     }
0200 
0201     return rc;
0202 }
0203 
0204 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
0205                 int flags)
0206 {
0207     struct scatterlist *sg;
0208     u16 offset;
0209 
0210     sg = ctx->partially_sent_record;
0211     offset = ctx->partially_sent_offset;
0212 
0213     ctx->partially_sent_record = NULL;
0214     return tls_push_sg(sk, ctx, sg, offset, flags);
0215 }
0216 
0217 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
0218 {
0219     struct scatterlist *sg;
0220 
0221     for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
0222         put_page(sg_page(sg));
0223         sk_mem_uncharge(sk, sg->length);
0224     }
0225     ctx->partially_sent_record = NULL;
0226 }
0227 
0228 static void tls_write_space(struct sock *sk)
0229 {
0230     struct tls_context *ctx = tls_get_ctx(sk);
0231 
0232     /* If in_tcp_sendpages call lower protocol write space handler
0233      * to ensure we wake up any waiting operations there. For example
0234      * if do_tcp_sendpages where to call sk_wait_event.
0235      */
0236     if (ctx->in_tcp_sendpages) {
0237         ctx->sk_write_space(sk);
0238         return;
0239     }
0240 
0241 #ifdef CONFIG_TLS_DEVICE
0242     if (ctx->tx_conf == TLS_HW)
0243         tls_device_write_space(sk, ctx);
0244     else
0245 #endif
0246         tls_sw_write_space(sk, ctx);
0247 
0248     ctx->sk_write_space(sk);
0249 }
0250 
0251 /**
0252  * tls_ctx_free() - free TLS ULP context
0253  * @sk:  socket to with @ctx is attached
0254  * @ctx: TLS context structure
0255  *
0256  * Free TLS context. If @sk is %NULL caller guarantees that the socket
0257  * to which @ctx was attached has no outstanding references.
0258  */
0259 void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
0260 {
0261     if (!ctx)
0262         return;
0263 
0264     memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
0265     memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
0266     mutex_destroy(&ctx->tx_lock);
0267 
0268     if (sk)
0269         kfree_rcu(ctx, rcu);
0270     else
0271         kfree(ctx);
0272 }
0273 
0274 static void tls_sk_proto_cleanup(struct sock *sk,
0275                  struct tls_context *ctx, long timeo)
0276 {
0277     if (unlikely(sk->sk_write_pending) &&
0278         !wait_on_pending_writer(sk, &timeo))
0279         tls_handle_open_record(sk, 0);
0280 
0281     /* We need these for tls_sw_fallback handling of other packets */
0282     if (ctx->tx_conf == TLS_SW) {
0283         kfree(ctx->tx.rec_seq);
0284         kfree(ctx->tx.iv);
0285         tls_sw_release_resources_tx(sk);
0286         TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
0287     } else if (ctx->tx_conf == TLS_HW) {
0288         tls_device_free_resources_tx(sk);
0289         TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
0290     }
0291 
0292     if (ctx->rx_conf == TLS_SW) {
0293         tls_sw_release_resources_rx(sk);
0294         TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
0295     } else if (ctx->rx_conf == TLS_HW) {
0296         tls_device_offload_cleanup_rx(sk);
0297         TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
0298     }
0299 }
0300 
0301 static void tls_sk_proto_close(struct sock *sk, long timeout)
0302 {
0303     struct inet_connection_sock *icsk = inet_csk(sk);
0304     struct tls_context *ctx = tls_get_ctx(sk);
0305     long timeo = sock_sndtimeo(sk, 0);
0306     bool free_ctx;
0307 
0308     if (ctx->tx_conf == TLS_SW)
0309         tls_sw_cancel_work_tx(ctx);
0310 
0311     lock_sock(sk);
0312     free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
0313 
0314     if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
0315         tls_sk_proto_cleanup(sk, ctx, timeo);
0316 
0317     write_lock_bh(&sk->sk_callback_lock);
0318     if (free_ctx)
0319         rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
0320     WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
0321     if (sk->sk_write_space == tls_write_space)
0322         sk->sk_write_space = ctx->sk_write_space;
0323     write_unlock_bh(&sk->sk_callback_lock);
0324     release_sock(sk);
0325     if (ctx->tx_conf == TLS_SW)
0326         tls_sw_free_ctx_tx(ctx);
0327     if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
0328         tls_sw_strparser_done(ctx);
0329     if (ctx->rx_conf == TLS_SW)
0330         tls_sw_free_ctx_rx(ctx);
0331     ctx->sk_proto->close(sk, timeout);
0332 
0333     if (free_ctx)
0334         tls_ctx_free(sk, ctx);
0335 }
0336 
0337 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
0338                   int __user *optlen, int tx)
0339 {
0340     int rc = 0;
0341     struct tls_context *ctx = tls_get_ctx(sk);
0342     struct tls_crypto_info *crypto_info;
0343     struct cipher_context *cctx;
0344     int len;
0345 
0346     if (get_user(len, optlen))
0347         return -EFAULT;
0348 
0349     if (!optval || (len < sizeof(*crypto_info))) {
0350         rc = -EINVAL;
0351         goto out;
0352     }
0353 
0354     if (!ctx) {
0355         rc = -EBUSY;
0356         goto out;
0357     }
0358 
0359     /* get user crypto info */
0360     if (tx) {
0361         crypto_info = &ctx->crypto_send.info;
0362         cctx = &ctx->tx;
0363     } else {
0364         crypto_info = &ctx->crypto_recv.info;
0365         cctx = &ctx->rx;
0366     }
0367 
0368     if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
0369         rc = -EBUSY;
0370         goto out;
0371     }
0372 
0373     if (len == sizeof(*crypto_info)) {
0374         if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
0375             rc = -EFAULT;
0376         goto out;
0377     }
0378 
0379     switch (crypto_info->cipher_type) {
0380     case TLS_CIPHER_AES_GCM_128: {
0381         struct tls12_crypto_info_aes_gcm_128 *
0382           crypto_info_aes_gcm_128 =
0383           container_of(crypto_info,
0384                    struct tls12_crypto_info_aes_gcm_128,
0385                    info);
0386 
0387         if (len != sizeof(*crypto_info_aes_gcm_128)) {
0388             rc = -EINVAL;
0389             goto out;
0390         }
0391         lock_sock(sk);
0392         memcpy(crypto_info_aes_gcm_128->iv,
0393                cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
0394                TLS_CIPHER_AES_GCM_128_IV_SIZE);
0395         memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
0396                TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
0397         release_sock(sk);
0398         if (copy_to_user(optval,
0399                  crypto_info_aes_gcm_128,
0400                  sizeof(*crypto_info_aes_gcm_128)))
0401             rc = -EFAULT;
0402         break;
0403     }
0404     case TLS_CIPHER_AES_GCM_256: {
0405         struct tls12_crypto_info_aes_gcm_256 *
0406           crypto_info_aes_gcm_256 =
0407           container_of(crypto_info,
0408                    struct tls12_crypto_info_aes_gcm_256,
0409                    info);
0410 
0411         if (len != sizeof(*crypto_info_aes_gcm_256)) {
0412             rc = -EINVAL;
0413             goto out;
0414         }
0415         lock_sock(sk);
0416         memcpy(crypto_info_aes_gcm_256->iv,
0417                cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
0418                TLS_CIPHER_AES_GCM_256_IV_SIZE);
0419         memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
0420                TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
0421         release_sock(sk);
0422         if (copy_to_user(optval,
0423                  crypto_info_aes_gcm_256,
0424                  sizeof(*crypto_info_aes_gcm_256)))
0425             rc = -EFAULT;
0426         break;
0427     }
0428     case TLS_CIPHER_AES_CCM_128: {
0429         struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
0430             container_of(crypto_info,
0431                 struct tls12_crypto_info_aes_ccm_128, info);
0432 
0433         if (len != sizeof(*aes_ccm_128)) {
0434             rc = -EINVAL;
0435             goto out;
0436         }
0437         lock_sock(sk);
0438         memcpy(aes_ccm_128->iv,
0439                cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
0440                TLS_CIPHER_AES_CCM_128_IV_SIZE);
0441         memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
0442                TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
0443         release_sock(sk);
0444         if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
0445             rc = -EFAULT;
0446         break;
0447     }
0448     case TLS_CIPHER_CHACHA20_POLY1305: {
0449         struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
0450             container_of(crypto_info,
0451                 struct tls12_crypto_info_chacha20_poly1305,
0452                 info);
0453 
0454         if (len != sizeof(*chacha20_poly1305)) {
0455             rc = -EINVAL;
0456             goto out;
0457         }
0458         lock_sock(sk);
0459         memcpy(chacha20_poly1305->iv,
0460                cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
0461                TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
0462         memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
0463                TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
0464         release_sock(sk);
0465         if (copy_to_user(optval, chacha20_poly1305,
0466                 sizeof(*chacha20_poly1305)))
0467             rc = -EFAULT;
0468         break;
0469     }
0470     case TLS_CIPHER_SM4_GCM: {
0471         struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
0472             container_of(crypto_info,
0473                 struct tls12_crypto_info_sm4_gcm, info);
0474 
0475         if (len != sizeof(*sm4_gcm_info)) {
0476             rc = -EINVAL;
0477             goto out;
0478         }
0479         lock_sock(sk);
0480         memcpy(sm4_gcm_info->iv,
0481                cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
0482                TLS_CIPHER_SM4_GCM_IV_SIZE);
0483         memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
0484                TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
0485         release_sock(sk);
0486         if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
0487             rc = -EFAULT;
0488         break;
0489     }
0490     case TLS_CIPHER_SM4_CCM: {
0491         struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
0492             container_of(crypto_info,
0493                 struct tls12_crypto_info_sm4_ccm, info);
0494 
0495         if (len != sizeof(*sm4_ccm_info)) {
0496             rc = -EINVAL;
0497             goto out;
0498         }
0499         lock_sock(sk);
0500         memcpy(sm4_ccm_info->iv,
0501                cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
0502                TLS_CIPHER_SM4_CCM_IV_SIZE);
0503         memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
0504                TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
0505         release_sock(sk);
0506         if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
0507             rc = -EFAULT;
0508         break;
0509     }
0510     default:
0511         rc = -EINVAL;
0512     }
0513 
0514 out:
0515     return rc;
0516 }
0517 
0518 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
0519                    int __user *optlen)
0520 {
0521     struct tls_context *ctx = tls_get_ctx(sk);
0522     unsigned int value;
0523     int len;
0524 
0525     if (get_user(len, optlen))
0526         return -EFAULT;
0527 
0528     if (len != sizeof(value))
0529         return -EINVAL;
0530 
0531     value = ctx->zerocopy_sendfile;
0532     if (copy_to_user(optval, &value, sizeof(value)))
0533         return -EFAULT;
0534 
0535     return 0;
0536 }
0537 
0538 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
0539                     int __user *optlen)
0540 {
0541     struct tls_context *ctx = tls_get_ctx(sk);
0542     int value, len;
0543 
0544     if (ctx->prot_info.version != TLS_1_3_VERSION)
0545         return -EINVAL;
0546 
0547     if (get_user(len, optlen))
0548         return -EFAULT;
0549     if (len < sizeof(value))
0550         return -EINVAL;
0551 
0552     lock_sock(sk);
0553     value = -EINVAL;
0554     if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
0555         value = ctx->rx_no_pad;
0556     release_sock(sk);
0557     if (value < 0)
0558         return value;
0559 
0560     if (put_user(sizeof(value), optlen))
0561         return -EFAULT;
0562     if (copy_to_user(optval, &value, sizeof(value)))
0563         return -EFAULT;
0564 
0565     return 0;
0566 }
0567 
0568 static int do_tls_getsockopt(struct sock *sk, int optname,
0569                  char __user *optval, int __user *optlen)
0570 {
0571     int rc = 0;
0572 
0573     switch (optname) {
0574     case TLS_TX:
0575     case TLS_RX:
0576         rc = do_tls_getsockopt_conf(sk, optval, optlen,
0577                         optname == TLS_TX);
0578         break;
0579     case TLS_TX_ZEROCOPY_RO:
0580         rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
0581         break;
0582     case TLS_RX_EXPECT_NO_PAD:
0583         rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
0584         break;
0585     default:
0586         rc = -ENOPROTOOPT;
0587         break;
0588     }
0589     return rc;
0590 }
0591 
0592 static int tls_getsockopt(struct sock *sk, int level, int optname,
0593               char __user *optval, int __user *optlen)
0594 {
0595     struct tls_context *ctx = tls_get_ctx(sk);
0596 
0597     if (level != SOL_TLS)
0598         return ctx->sk_proto->getsockopt(sk, level,
0599                          optname, optval, optlen);
0600 
0601     return do_tls_getsockopt(sk, optname, optval, optlen);
0602 }
0603 
/* Configure the crypto parameters for one direction of a TLS socket.
 *
 * @sk:     socket being configured
 * @optval: user supplied crypto info (generic header plus cipher data)
 * @optlen: size of @optval; must equal the size of the cipher's structure
 * @tx:     non-zero to configure transmit, zero to configure receive
 *
 * Device offload is attempted first; if that fails the software
 * implementation is set up instead.  On success the direction's conf is
 * recorded in the context and the socket's proto/proto_ops are switched to
 * the matching variant.
 *
 * The visible caller, do_tls_setsockopt(), invokes this with the socket
 * lock held.
 *
 * Returns 0 on success or a negative errno.  On every failure after the
 * direction has been selected the direction's crypto context is wiped.
 */
0604 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
0605                   unsigned int optlen, int tx)
0606 {
0607     struct tls_crypto_info *crypto_info;
0608     struct tls_crypto_info *alt_crypto_info;
0609     struct tls_context *ctx = tls_get_ctx(sk);
0610     size_t optsize;
0611     int rc = 0;
0612     int conf;
0613 
0614     if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
0615         return -EINVAL;
0616 
    /* crypto_info is the direction being set, alt_crypto_info the
     * opposite one (used for the consistency check below).
     */
0617     if (tx) {
0618         crypto_info = &ctx->crypto_send.info;
0619         alt_crypto_info = &ctx->crypto_recv.info;
0620     } else {
0621         crypto_info = &ctx->crypto_recv.info;
0622         alt_crypto_info = &ctx->crypto_send.info;
0623     }
0624 
0625     /* Currently we don't support set crypto info more than one time */
0626     if (TLS_CRYPTO_INFO_READY(crypto_info))
0627         return -EBUSY;
0628 
    /* Copy only the generic header first so that version and cipher type
     * can be validated before the cipher specific part is read.
     */
0629     rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
0630     if (rc) {
0631         rc = -EFAULT;
0632         goto err_crypto_info;
0633     }
0634 
0635     /* check version */
0636     if (crypto_info->version != TLS_1_2_VERSION &&
0637         crypto_info->version != TLS_1_3_VERSION) {
0638         rc = -EINVAL;
0639         goto err_crypto_info;
0640     }
0641 
0642     /* Ensure that TLS version and ciphers are same in both directions */
0643     if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
0644         if (alt_crypto_info->version != crypto_info->version ||
0645             alt_crypto_info->cipher_type != crypto_info->cipher_type) {
0646             rc = -EINVAL;
0647             goto err_crypto_info;
0648         }
0649     }
0650 
    /* Expected total option size for the requested cipher. */
0651     switch (crypto_info->cipher_type) {
0652     case TLS_CIPHER_AES_GCM_128:
0653         optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
0654         break;
0655     case TLS_CIPHER_AES_GCM_256: {
0656         optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
0657         break;
0658     }
0659     case TLS_CIPHER_AES_CCM_128:
0660         optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
0661         break;
0662     case TLS_CIPHER_CHACHA20_POLY1305:
0663         optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
0664         break;
0665     case TLS_CIPHER_SM4_GCM:
0666         optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
0667         break;
0668     case TLS_CIPHER_SM4_CCM:
0669         optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
0670         break;
0671     default:
0672         rc = -EINVAL;
0673         goto err_crypto_info;
0674     }
0675 
0676     if (optlen != optsize) {
0677         rc = -EINVAL;
0678         goto err_crypto_info;
0679     }
0680 
    /* Copy the cipher specific remainder directly behind the header. */
0681     rc = copy_from_sockptr_offset(crypto_info + 1, optval,
0682                       sizeof(*crypto_info),
0683                       optlen - sizeof(*crypto_info));
0684     if (rc) {
0685         rc = -EFAULT;
0686         goto err_crypto_info;
0687     }
0688 
    /* conf is set to TLS_HW up front and overwritten with TLS_SW when the
     * device offload attempt fails and the software setup succeeds.
     */
0689     if (tx) {
0690         rc = tls_set_device_offload(sk, ctx);
0691         conf = TLS_HW;
0692         if (!rc) {
0693             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
0694             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
0695         } else {
0696             rc = tls_set_sw_offload(sk, ctx, 1);
0697             if (rc)
0698                 goto err_crypto_info;
0699             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
0700             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
0701             conf = TLS_SW;
0702         }
0703     } else {
0704         rc = tls_set_device_offload_rx(sk, ctx);
0705         conf = TLS_HW;
0706         if (!rc) {
0707             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
0708             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
0709         } else {
0710             rc = tls_set_sw_offload(sk, ctx, 0);
0711             if (rc)
0712                 goto err_crypto_info;
0713             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
0714             TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
0715             conf = TLS_SW;
0716         }
0717         tls_sw_strparser_arm(sk, ctx);
0718     }
0719 
    /* Record the new configuration, then switch proto/proto_ops to it. */
0720     if (tx)
0721         ctx->tx_conf = conf;
0722     else
0723         ctx->rx_conf = conf;
0724     update_sk_prot(sk, ctx);
0725     if (tx) {
        /* Save the current write_space callback; tls_write_space()
         * chains to it.
         */
0726         ctx->sk_write_space = sk->sk_write_space;
0727         sk->sk_write_space = tls_write_space;
0728     } else {
0729         struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);
0730 
0731         tls_strp_check_rcv(&rx_ctx->strp);
0732     }
0733     return 0;
0734 
0735 err_crypto_info:
    /* Wipe the whole crypto context of this direction, not just the header. */
0736     memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
0737     return rc;
0738 }
0739 
0740 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
0741                    unsigned int optlen)
0742 {
0743     struct tls_context *ctx = tls_get_ctx(sk);
0744     unsigned int value;
0745 
0746     if (sockptr_is_null(optval) || optlen != sizeof(value))
0747         return -EINVAL;
0748 
0749     if (copy_from_sockptr(&value, optval, sizeof(value)))
0750         return -EFAULT;
0751 
0752     if (value > 1)
0753         return -EINVAL;
0754 
0755     ctx->zerocopy_sendfile = value;
0756 
0757     return 0;
0758 }
0759 
0760 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
0761                     unsigned int optlen)
0762 {
0763     struct tls_context *ctx = tls_get_ctx(sk);
0764     u32 val;
0765     int rc;
0766 
0767     if (ctx->prot_info.version != TLS_1_3_VERSION ||
0768         sockptr_is_null(optval) || optlen < sizeof(val))
0769         return -EINVAL;
0770 
0771     rc = copy_from_sockptr(&val, optval, sizeof(val));
0772     if (rc)
0773         return -EFAULT;
0774     if (val > 1)
0775         return -EINVAL;
0776     rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
0777     if (rc < 1)
0778         return rc == 0 ? -EINVAL : rc;
0779 
0780     lock_sock(sk);
0781     rc = -EINVAL;
0782     if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
0783         ctx->rx_no_pad = val;
0784         tls_update_rx_zc_capable(ctx);
0785         rc = 0;
0786     }
0787     release_sock(sk);
0788 
0789     return rc;
0790 }
0791 
0792 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
0793                  unsigned int optlen)
0794 {
0795     int rc = 0;
0796 
0797     switch (optname) {
0798     case TLS_TX:
0799     case TLS_RX:
0800         lock_sock(sk);
0801         rc = do_tls_setsockopt_conf(sk, optval, optlen,
0802                         optname == TLS_TX);
0803         release_sock(sk);
0804         break;
0805     case TLS_TX_ZEROCOPY_RO:
0806         lock_sock(sk);
0807         rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
0808         release_sock(sk);
0809         break;
0810     case TLS_RX_EXPECT_NO_PAD:
0811         rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
0812         break;
0813     default:
0814         rc = -ENOPROTOOPT;
0815         break;
0816     }
0817     return rc;
0818 }
0819 
0820 static int tls_setsockopt(struct sock *sk, int level, int optname,
0821               sockptr_t optval, unsigned int optlen)
0822 {
0823     struct tls_context *ctx = tls_get_ctx(sk);
0824 
0825     if (level != SOL_TLS)
0826         return ctx->sk_proto->setsockopt(sk, level, optname, optval,
0827                          optlen);
0828 
0829     return do_tls_setsockopt(sk, optname, optval, optlen);
0830 }
0831 
0832 struct tls_context *tls_ctx_create(struct sock *sk)
0833 {
0834     struct inet_connection_sock *icsk = inet_csk(sk);
0835     struct tls_context *ctx;
0836 
0837     ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
0838     if (!ctx)
0839         return NULL;
0840 
0841     mutex_init(&ctx->tx_lock);
0842     rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
0843     ctx->sk_proto = READ_ONCE(sk->sk_prot);
0844     ctx->sk = sk;
0845     return ctx;
0846 }
0847 
0848 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
0849                 const struct proto_ops *base)
0850 {
0851     ops[TLS_BASE][TLS_BASE] = *base;
0852 
0853     ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
0854     ops[TLS_SW  ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
0855 
0856     ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
0857     ops[TLS_BASE][TLS_SW  ].splice_read = tls_sw_splice_read;
0858 
0859     ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
0860     ops[TLS_SW  ][TLS_SW  ].splice_read = tls_sw_splice_read;
0861 
0862 #ifdef CONFIG_TLS_DEVICE
0863     ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
0864     ops[TLS_HW  ][TLS_BASE].sendpage_locked = NULL;
0865 
0866     ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
0867     ops[TLS_HW  ][TLS_SW  ].sendpage_locked = NULL;
0868 
0869     ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];
0870 
0871     ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];
0872 
0873     ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
0874     ops[TLS_HW  ][TLS_HW  ].sendpage_locked = NULL;
0875 #endif
0876 #ifdef CONFIG_TLS_TOE
0877     ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
0878 #endif
0879 }
0880 
0881 static void tls_build_proto(struct sock *sk)
0882 {
0883     int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
0884     struct proto *prot = READ_ONCE(sk->sk_prot);
0885 
0886     /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
0887     if (ip_ver == TLSV6 &&
0888         unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
0889         mutex_lock(&tcpv6_prot_mutex);
0890         if (likely(prot != saved_tcpv6_prot)) {
0891             build_protos(tls_prots[TLSV6], prot);
0892             build_proto_ops(tls_proto_ops[TLSV6],
0893                     sk->sk_socket->ops);
0894             smp_store_release(&saved_tcpv6_prot, prot);
0895         }
0896         mutex_unlock(&tcpv6_prot_mutex);
0897     }
0898 
0899     if (ip_ver == TLSV4 &&
0900         unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
0901         mutex_lock(&tcpv4_prot_mutex);
0902         if (likely(prot != saved_tcpv4_prot)) {
0903             build_protos(tls_prots[TLSV4], prot);
0904             build_proto_ops(tls_proto_ops[TLSV4],
0905                     sk->sk_socket->ops);
0906             smp_store_release(&saved_tcpv4_prot, prot);
0907         }
0908         mutex_unlock(&tcpv4_prot_mutex);
0909     }
0910 }
0911 
0912 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
0913              const struct proto *base)
0914 {
0915     prot[TLS_BASE][TLS_BASE] = *base;
0916     prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
0917     prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
0918     prot[TLS_BASE][TLS_BASE].close      = tls_sk_proto_close;
0919 
0920     prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
0921     prot[TLS_SW][TLS_BASE].sendmsg      = tls_sw_sendmsg;
0922     prot[TLS_SW][TLS_BASE].sendpage     = tls_sw_sendpage;
0923 
0924     prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
0925     prot[TLS_BASE][TLS_SW].recvmsg        = tls_sw_recvmsg;
0926     prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
0927     prot[TLS_BASE][TLS_SW].close          = tls_sk_proto_close;
0928 
0929     prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
0930     prot[TLS_SW][TLS_SW].recvmsg        = tls_sw_recvmsg;
0931     prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
0932     prot[TLS_SW][TLS_SW].close      = tls_sk_proto_close;
0933 
0934 #ifdef CONFIG_TLS_DEVICE
0935     prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
0936     prot[TLS_HW][TLS_BASE].sendmsg      = tls_device_sendmsg;
0937     prot[TLS_HW][TLS_BASE].sendpage     = tls_device_sendpage;
0938 
0939     prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
0940     prot[TLS_HW][TLS_SW].sendmsg        = tls_device_sendmsg;
0941     prot[TLS_HW][TLS_SW].sendpage       = tls_device_sendpage;
0942 
0943     prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
0944 
0945     prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
0946 
0947     prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
0948 #endif
0949 #ifdef CONFIG_TLS_TOE
0950     prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
0951     prot[TLS_HW_RECORD][TLS_HW_RECORD].hash     = tls_toe_hash;
0952     prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash   = tls_toe_unhash;
0953 #endif
0954 }
0955 
0956 static int tls_init(struct sock *sk)
0957 {
0958     struct tls_context *ctx;
0959     int rc = 0;
0960 
0961     tls_build_proto(sk);
0962 
0963 #ifdef CONFIG_TLS_TOE
0964     if (tls_toe_bypass(sk))
0965         return 0;
0966 #endif
0967 
0968     /* The TLS ulp is currently supported only for TCP sockets
0969      * in ESTABLISHED state.
0970      * Supporting sockets in LISTEN state will require us
0971      * to modify the accept implementation to clone rather then
0972      * share the ulp context.
0973      */
0974     if (sk->sk_state != TCP_ESTABLISHED)
0975         return -ENOTCONN;
0976 
0977     /* allocate tls context */
0978     write_lock_bh(&sk->sk_callback_lock);
0979     ctx = tls_ctx_create(sk);
0980     if (!ctx) {
0981         rc = -ENOMEM;
0982         goto out;
0983     }
0984 
0985     ctx->tx_conf = TLS_BASE;
0986     ctx->rx_conf = TLS_BASE;
0987     update_sk_prot(sk, ctx);
0988 out:
0989     write_unlock_bh(&sk->sk_callback_lock);
0990     return rc;
0991 }
0992 
0993 static void tls_update(struct sock *sk, struct proto *p,
0994                void (*write_space)(struct sock *sk))
0995 {
0996     struct tls_context *ctx;
0997 
0998     WARN_ON_ONCE(sk->sk_prot == p);
0999 
1000     ctx = tls_get_ctx(sk);
1001     if (likely(ctx)) {
1002         ctx->sk_write_space = write_space;
1003         ctx->sk_proto = p;
1004     } else {
1005         /* Pairs with lockless read in sk_clone_lock(). */
1006         WRITE_ONCE(sk->sk_prot, p);
1007         sk->sk_write_space = write_space;
1008     }
1009 }
1010 
1011 static u16 tls_user_config(struct tls_context *ctx, bool tx)
1012 {
1013     u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
1014 
1015     switch (config) {
1016     case TLS_BASE:
1017         return TLS_CONF_BASE;
1018     case TLS_SW:
1019         return TLS_CONF_SW;
1020     case TLS_HW:
1021         return TLS_CONF_HW;
1022     case TLS_HW_RECORD:
1023         return TLS_CONF_HW_RECORD;
1024     }
1025     return 0;
1026 }
1027 
1028 static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
1029 {
1030     u16 version, cipher_type;
1031     struct tls_context *ctx;
1032     struct nlattr *start;
1033     int err;
1034 
1035     start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
1036     if (!start)
1037         return -EMSGSIZE;
1038 
1039     rcu_read_lock();
1040     ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
1041     if (!ctx) {
1042         err = 0;
1043         goto nla_failure;
1044     }
1045     version = ctx->prot_info.version;
1046     if (version) {
1047         err = nla_put_u16(skb, TLS_INFO_VERSION, version);
1048         if (err)
1049             goto nla_failure;
1050     }
1051     cipher_type = ctx->prot_info.cipher_type;
1052     if (cipher_type) {
1053         err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
1054         if (err)
1055             goto nla_failure;
1056     }
1057     err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
1058     if (err)
1059         goto nla_failure;
1060 
1061     err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
1062     if (err)
1063         goto nla_failure;
1064 
1065     if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
1066         err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
1067         if (err)
1068             goto nla_failure;
1069     }
1070     if (ctx->rx_no_pad) {
1071         err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
1072         if (err)
1073             goto nla_failure;
1074     }
1075 
1076     rcu_read_unlock();
1077     nla_nest_end(skb, start);
1078     return 0;
1079 
1080 nla_failure:
1081     rcu_read_unlock();
1082     nla_nest_cancel(skb, start);
1083     return err;
1084 }
1085 
1086 static size_t tls_get_info_size(const struct sock *sk)
1087 {
1088     size_t size = 0;
1089 
1090     size += nla_total_size(0) +     /* INET_ULP_INFO_TLS */
1091         nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
1092         nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
1093         nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
1094         nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
1095         nla_total_size(0) +     /* TLS_INFO_ZC_RO_TX */
1096         nla_total_size(0) +     /* TLS_INFO_RX_NO_PAD */
1097         0;
1098 
1099     return size;
1100 }
1101 
1102 static int __net_init tls_init_net(struct net *net)
1103 {
1104     int err;
1105 
1106     net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
1107     if (!net->mib.tls_statistics)
1108         return -ENOMEM;
1109 
1110     err = tls_proc_init(net);
1111     if (err)
1112         goto err_free_stats;
1113 
1114     return 0;
1115 err_free_stats:
1116     free_percpu(net->mib.tls_statistics);
1117     return err;
1118 }
1119 
1120 static void __net_exit tls_exit_net(struct net *net)
1121 {
1122     tls_proc_fini(net);
1123     free_percpu(net->mib.tls_statistics);
1124 }
1125 
1126 static struct pernet_operations tls_proc_ops = {
1127     .init = tls_init_net,
1128     .exit = tls_exit_net,
1129 };
1130 
1131 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
1132     .name           = "tls",
1133     .owner          = THIS_MODULE,
1134     .init           = tls_init,
1135     .update         = tls_update,
1136     .get_info       = tls_get_info,
1137     .get_info_size      = tls_get_info_size,
1138 };
1139 
1140 static int __init tls_register(void)
1141 {
1142     int err;
1143 
1144     err = register_pernet_subsys(&tls_proc_ops);
1145     if (err)
1146         return err;
1147 
1148     err = tls_strp_dev_init();
1149     if (err)
1150         goto err_pernet;
1151 
1152     err = tls_device_init();
1153     if (err)
1154         goto err_strp;
1155 
1156     tcp_register_ulp(&tcp_tls_ulp_ops);
1157 
1158     return 0;
1159 err_strp:
1160     tls_strp_dev_exit();
1161 err_pernet:
1162     unregister_pernet_subsys(&tls_proc_ops);
1163     return err;
1164 }
1165 
1166 static void __exit tls_unregister(void)
1167 {
1168     tcp_unregister_ulp(&tcp_tls_ulp_ops);
1169     tls_strp_dev_exit();
1170     tls_device_cleanup();
1171     unregister_pernet_subsys(&tls_proc_ops);
1172 }
1173 
/* Module entry and exit points. */
module_init(tls_register);
module_exit(tls_unregister);