/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/bug.h>
#include <linux/sched/signal.h>
#include <linux/module.h>
#include <linux/splice.h>
#include <crypto/aead.h>

#include <net/strparser.h>
#include <net/tls.h>

#include "tls.h"

struct tls_decrypt_arg {
    struct_group(inargs,
    bool zc;
    bool async;
    u8 tail;
    );

    struct sk_buff *skb;
};

struct tls_decrypt_ctx {
    u8 iv[MAX_IV_SIZE];
    u8 aad[TLS_MAX_AAD_SIZE];
    u8 tail;
    struct scatterlist sg[];
};

noinline void tls_err_abort(struct sock *sk, int err)
{
    WARN_ON_ONCE(err >= 0);
    /* sk->sk_err should contain a positive error code. */
    sk->sk_err = -err;
    sk_error_report(sk);
}
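
/* Example (sketch): callers pass a negative errno, e.g.
 * tls_err_abort(sk, -EBADMSG) as in tls_tx_records() below, which
 * leaves sk->sk_err == EBADMSG and wakes any waiters through
 * sk_error_report().
 */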

static int __skb_nsg(struct sk_buff *skb, int offset, int len,
                     unsigned int recursion_level)
{
    int start = skb_headlen(skb);
    int i, chunk = start - offset;
    struct sk_buff *frag_iter;
    int elt = 0;

    if (unlikely(recursion_level >= 24))
        return -EMSGSIZE;

    if (chunk > 0) {
        if (chunk > len)
            chunk = len;
        elt++;
        len -= chunk;
        if (len == 0)
            return elt;
        offset += chunk;
    }

    for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
        int end;

        WARN_ON(start > offset + len);

        end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
        chunk = end - offset;
        if (chunk > 0) {
            if (chunk > len)
                chunk = len;
            elt++;
            len -= chunk;
            if (len == 0)
                return elt;
            offset += chunk;
        }
        start = end;
    }

    if (unlikely(skb_has_frag_list(skb))) {
        skb_walk_frags(skb, frag_iter) {
            int end, ret;

            WARN_ON(start > offset + len);

            end = start + frag_iter->len;
            chunk = end - offset;
            if (chunk > 0) {
                if (chunk > len)
                    chunk = len;
                ret = __skb_nsg(frag_iter, offset - start, chunk,
                                recursion_level + 1);
                if (unlikely(ret < 0))
                    return ret;
                elt += ret;
                len -= chunk;
                if (len == 0)
                    return elt;
                offset += chunk;
            }
            start = end;
        }
    }
    BUG_ON(len);
    return elt;
}

/* Return the number of scatterlist elements required to completely map the
 * skb, or -EMSGSIZE if the recursion depth is exceeded.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
    return __skb_nsg(skb, offset, len, 0);
}
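
/* Worked example (sketch): for a linear skb, with all of the data in
 * the head, any (offset, len) range maps to a single element, so
 * skb_nsg() returns 1; every page fragment or frag_list skb that the
 * range touches adds at least one more element.
 */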

static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
                  struct tls_decrypt_arg *darg)
{
    struct strp_msg *rxm = strp_msg(skb);
    struct tls_msg *tlm = tls_msg(skb);
    int sub = 0;

    /* Determine zero-padding length */
    if (prot->version == TLS_1_3_VERSION) {
        int offset = rxm->full_len - TLS_TAG_SIZE - 1;
        char content_type = darg->zc ? darg->tail : 0;
        int err;

        while (content_type == 0) {
            if (offset < prot->prepend_size)
                return -EBADMSG;
            err = skb_copy_bits(skb, rxm->offset + offset,
                        &content_type, 1);
            if (err)
                return err;
            if (content_type)
                break;
            sub++;
            offset--;
        }
        tlm->control = content_type;
    }
    return sub;
}
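
/* Example (sketch): a TLS 1.3 inner plaintext is payload ||
 * content-type || zero padding, e.g. ... 'd' 'a' 't' 'a' 0x17 0x00 0x00
 * just before the tag. Scanning backwards skips the two zero bytes
 * (sub == 2) and stores 0x17 (application data) in tlm->control.
 */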

static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
    struct aead_request *aead_req = (struct aead_request *)req;
    struct scatterlist *sgout = aead_req->dst;
    struct scatterlist *sgin = aead_req->src;
    struct tls_sw_context_rx *ctx;
    struct tls_context *tls_ctx;
    struct scatterlist *sg;
    unsigned int pages;
    struct sock *sk;

    sk = (struct sock *)req->data;
    tls_ctx = tls_get_ctx(sk);
    ctx = tls_sw_ctx_rx(tls_ctx);

    /* Propagate the error, if any */
    if (err) {
        if (err == -EBADMSG)
            TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
        ctx->async_wait.err = err;
        tls_err_abort(sk, err);
    }

    /* Free the destination pages if the skb was not decrypted in place */
    if (sgout != sgin) {
        /* Skip the first S/G entry as it points to AAD */
        for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
            if (!sg)
                break;
            put_page(sg_page(sg));
        }
    }

    kfree(aead_req);

    spin_lock_bh(&ctx->decrypt_compl_lock);
    if (!atomic_dec_return(&ctx->decrypt_pending))
        complete(&ctx->async_wait.completion);
    spin_unlock_bh(&ctx->decrypt_compl_lock);
}

static int tls_do_decryption(struct sock *sk,
                 struct scatterlist *sgin,
                 struct scatterlist *sgout,
                 char *iv_recv,
                 size_t data_len,
                 struct aead_request *aead_req,
                 struct tls_decrypt_arg *darg)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
    int ret;

    aead_request_set_tfm(aead_req, ctx->aead_recv);
    aead_request_set_ad(aead_req, prot->aad_size);
    aead_request_set_crypt(aead_req, sgin, sgout,
                   data_len + prot->tag_size,
                   (u8 *)iv_recv);

    if (darg->async) {
        aead_request_set_callback(aead_req,
                      CRYPTO_TFM_REQ_MAY_BACKLOG,
                      tls_decrypt_done, sk);
        atomic_inc(&ctx->decrypt_pending);
    } else {
        aead_request_set_callback(aead_req,
                      CRYPTO_TFM_REQ_MAY_BACKLOG,
                      crypto_req_done, &ctx->async_wait);
    }

    ret = crypto_aead_decrypt(aead_req);
    if (ret == -EINPROGRESS) {
        if (darg->async)
            return 0;

        ret = crypto_wait_req(ret, &ctx->async_wait);
    }
    darg->async = false;

    return ret;
}
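
/* Completion accounting note (sketch): with darg->async set, an
 * -EINPROGRESS return from the cipher is reported as success (0) and
 * tls_decrypt_done() later drops ctx->decrypt_pending; in the
 * synchronous case crypto_wait_req() blocks until the request is done.
 */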

static void tls_trim_both_msgs(struct sock *sk, int target_size)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec;

    sk_msg_trim(sk, &rec->msg_plaintext, target_size);
    if (target_size > 0)
        target_size += prot->overhead_size;
    sk_msg_trim(sk, &rec->msg_encrypted, target_size);
}

static int tls_alloc_encrypted_msg(struct sock *sk, int len)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec;
    struct sk_msg *msg_en = &rec->msg_encrypted;

    return sk_msg_alloc(sk, msg_en, len, 0);
}

static int tls_clone_plaintext_msg(struct sock *sk, int required)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec;
    struct sk_msg *msg_pl = &rec->msg_plaintext;
    struct sk_msg *msg_en = &rec->msg_encrypted;
    int skip, len;

    /* We add page references worth len bytes from the encrypted sg
     * at the end of the plaintext sg. The caller guarantees that
     * msg_en has enough room for them.
     */
    len = required - msg_pl->sg.size;

    /* Skip the initial bytes in msg_en's data so that the plaintext
     * and encrypted data share the same offset.
     */
    skip = prot->prepend_size + msg_pl->sg.size;

    return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
}

static struct tls_rec *tls_get_rec(struct sock *sk)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct sk_msg *msg_pl, *msg_en;
    struct tls_rec *rec;
    int mem_size;

    mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);

    rec = kzalloc(mem_size, sk->sk_allocation);
    if (!rec)
        return NULL;

    msg_pl = &rec->msg_plaintext;
    msg_en = &rec->msg_encrypted;

    sk_msg_init(msg_pl);
    sk_msg_init(msg_en);

    sg_init_table(rec->sg_aead_in, 2);
    sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
    sg_unmark_end(&rec->sg_aead_in[1]);

    sg_init_table(rec->sg_aead_out, 2);
    sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
    sg_unmark_end(&rec->sg_aead_out[1]);

    return rec;
}
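
/* Layout of the AEAD scatterlists set up above (sketch):
 *
 *   sg_aead_in[0]  -> rec->aad_space (prot->aad_size bytes)
 *   sg_aead_in[1]  -> chained to the msg_plaintext pages later
 *   sg_aead_out[0] -> rec->aad_space
 *   sg_aead_out[1] -> chained to the msg_encrypted pages later
 *
 * The chaining is done with sg_chain() in tls_push_record().
 */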

static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
{
    sk_msg_free(sk, &rec->msg_encrypted);
    sk_msg_free(sk, &rec->msg_plaintext);
    kfree(rec);
}

static void tls_free_open_rec(struct sock *sk)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec;

    if (rec) {
        tls_free_rec(sk, rec);
        ctx->open_rec = NULL;
    }
}

int tls_tx_records(struct sock *sk, int flags)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec, *tmp;
    struct sk_msg *msg_en;
    int tx_flags, rc = 0;

    if (tls_is_partially_sent_record(tls_ctx)) {
        rec = list_first_entry(&ctx->tx_list,
                       struct tls_rec, list);

        if (flags == -1)
            tx_flags = rec->tx_flags;
        else
            tx_flags = flags;

        rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
        if (rc)
            goto tx_err;

        /* The full record has been transmitted.
         * Remove the head of tx_list.
         */
        list_del(&rec->list);
        sk_msg_free(sk, &rec->msg_plaintext);
        kfree(rec);
    }

    /* Transmit all records that are ready */
    list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
        if (READ_ONCE(rec->tx_ready)) {
            if (flags == -1)
                tx_flags = rec->tx_flags;
            else
                tx_flags = flags;

            msg_en = &rec->msg_encrypted;
            rc = tls_push_sg(sk, tls_ctx,
                     &msg_en->sg.data[msg_en->sg.curr],
                     0, tx_flags);
            if (rc)
                goto tx_err;

            list_del(&rec->list);
            sk_msg_free(sk, &rec->msg_plaintext);
            kfree(rec);
        } else {
            break;
        }
    }

tx_err:
    if (rc < 0 && rc != -EAGAIN)
        tls_err_abort(sk, -EBADMSG);

    return rc;
}

static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
    struct aead_request *aead_req = (struct aead_request *)req;
    struct sock *sk = req->data;
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct scatterlist *sge;
    struct sk_msg *msg_en;
    struct tls_rec *rec;
    bool ready = false;
    int pending;

    rec = container_of(aead_req, struct tls_rec, aead_req);
    msg_en = &rec->msg_encrypted;

    sge = sk_msg_elem(msg_en, msg_en->sg.curr);
    sge->offset -= prot->prepend_size;
    sge->length += prot->prepend_size;

    /* Check if an error was previously set on the socket */
    if (err || sk->sk_err) {
        rec = NULL;

        /* If an error is already set on the socket, return the same code */
        if (sk->sk_err) {
            ctx->async_wait.err = -sk->sk_err;
        } else {
            ctx->async_wait.err = err;
            tls_err_abort(sk, err);
        }
    }

    if (rec) {
        struct tls_rec *first_rec;

        /* Mark the record as ready for transmission */
        smp_store_mb(rec->tx_ready, true);

        /* If the completed record is at the head of tx_list, schedule tx */
        first_rec = list_first_entry(&ctx->tx_list,
                         struct tls_rec, list);
        if (rec == first_rec)
            ready = true;
    }

    spin_lock_bh(&ctx->encrypt_compl_lock);
    pending = atomic_dec_return(&ctx->encrypt_pending);

    if (!pending && ctx->async_notify)
        complete(&ctx->async_wait.completion);
    spin_unlock_bh(&ctx->encrypt_compl_lock);

    if (!ready)
        return;

    /* Schedule the transmission */
    if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
        schedule_delayed_work(&ctx->tx_work.work, 1);
}

static int tls_do_encryption(struct sock *sk,
                 struct tls_context *tls_ctx,
                 struct tls_sw_context_tx *ctx,
                 struct aead_request *aead_req,
                 size_t data_len, u32 start)
{
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_rec *rec = ctx->open_rec;
    struct sk_msg *msg_en = &rec->msg_encrypted;
    struct scatterlist *sge = sk_msg_elem(msg_en, start);
    int rc, iv_offset = 0;

    /* For CCM-based ciphers, the first byte of the IV is a constant */
    switch (prot->cipher_type) {
    case TLS_CIPHER_AES_CCM_128:
        rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
        iv_offset = 1;
        break;
    case TLS_CIPHER_SM4_CCM:
        rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
        iv_offset = 1;
        break;
    }

    memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
           prot->iv_size + prot->salt_size);

    tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
                tls_ctx->tx.rec_seq);

    sge->offset += prot->prepend_size;
    sge->length -= prot->prepend_size;

    msg_en->sg.curr = start;

    aead_request_set_tfm(aead_req, ctx->aead_send);
    aead_request_set_ad(aead_req, prot->aad_size);
    aead_request_set_crypt(aead_req, rec->sg_aead_in,
                   rec->sg_aead_out,
                   data_len, rec->iv_data);

    aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                  tls_encrypt_done, sk);

    /* Add the record to tx_list */
    list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
    atomic_inc(&ctx->encrypt_pending);

    rc = crypto_aead_encrypt(aead_req);
    if (!rc || rc != -EINPROGRESS) {
        atomic_dec(&ctx->encrypt_pending);
        sge->offset -= prot->prepend_size;
        sge->length += prot->prepend_size;
    }

    if (!rc) {
        WRITE_ONCE(rec->tx_ready, true);
    } else if (rc != -EINPROGRESS) {
        list_del(&rec->list);
        return rc;
    }

    /* Unhook the record from the context; encryption either succeeded
     * or is still in progress.
     */
    ctx->open_rec = NULL;
    tls_advance_record_sn(sk, prot, &tls_ctx->tx);
    return rc;
}
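
/* IV construction above (sketch): the per-record nonce is
 * salt || iv, XOR-ed with the 64-bit record sequence number by
 * tls_xor_iv_with_seq() for TLS 1.3 and ChaCha20-Poly1305; CCM
 * ciphers additionally reserve byte 0 for the constant B0 byte,
 * hence iv_offset == 1.
 */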

static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
                 struct tls_rec **to, struct sk_msg *msg_opl,
                 struct sk_msg *msg_oen, u32 split_point,
                 u32 tx_overhead_size, u32 *orig_end)
{
    u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
    struct scatterlist *sge, *osge, *nsge;
    u32 orig_size = msg_opl->sg.size;
    struct scatterlist tmp = { };
    struct sk_msg *msg_npl;
    struct tls_rec *new;
    int ret;

    new = tls_get_rec(sk);
    if (!new)
        return -ENOMEM;
    ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
               tx_overhead_size, 0);
    if (ret < 0) {
        tls_free_rec(sk, new);
        return ret;
    }

    *orig_end = msg_opl->sg.end;
    i = msg_opl->sg.start;
    sge = sk_msg_elem(msg_opl, i);
    while (apply && sge->length) {
        if (sge->length > apply) {
            u32 len = sge->length - apply;

            get_page(sg_page(sge));
            sg_set_page(&tmp, sg_page(sge), len,
                    sge->offset + apply);
            sge->length = apply;
            bytes += apply;
            apply = 0;
        } else {
            apply -= sge->length;
            bytes += sge->length;
        }

        sk_msg_iter_var_next(i);
        if (i == msg_opl->sg.end)
            break;
        sge = sk_msg_elem(msg_opl, i);
    }

    msg_opl->sg.end = i;
    msg_opl->sg.curr = i;
    msg_opl->sg.copybreak = 0;
    msg_opl->apply_bytes = 0;
    msg_opl->sg.size = bytes;

    msg_npl = &new->msg_plaintext;
    msg_npl->apply_bytes = apply;
    msg_npl->sg.size = orig_size - bytes;

    j = msg_npl->sg.start;
    nsge = sk_msg_elem(msg_npl, j);
    if (tmp.length) {
        memcpy(nsge, &tmp, sizeof(*nsge));
        sk_msg_iter_var_next(j);
        nsge = sk_msg_elem(msg_npl, j);
    }

    osge = sk_msg_elem(msg_opl, i);
    while (osge->length) {
        memcpy(nsge, osge, sizeof(*nsge));
        sg_unmark_end(nsge);
        sk_msg_iter_var_next(i);
        sk_msg_iter_var_next(j);
        if (i == *orig_end)
            break;
        osge = sk_msg_elem(msg_opl, i);
        nsge = sk_msg_elem(msg_npl, j);
    }

    msg_npl->sg.end = j;
    msg_npl->sg.curr = j;
    msg_npl->sg.copybreak = 0;

    *to = new;
    return 0;
}

static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
                  struct tls_rec *from, u32 orig_end)
{
    struct sk_msg *msg_npl = &from->msg_plaintext;
    struct sk_msg *msg_opl = &to->msg_plaintext;
    struct scatterlist *osge, *nsge;
    u32 i, j;

    i = msg_opl->sg.end;
    sk_msg_iter_var_prev(i);
    j = msg_npl->sg.start;

    osge = sk_msg_elem(msg_opl, i);
    nsge = sk_msg_elem(msg_npl, j);

    if (sg_page(osge) == sg_page(nsge) &&
        osge->offset + osge->length == nsge->offset) {
        osge->length += nsge->length;
        put_page(sg_page(nsge));
    }

    msg_opl->sg.end = orig_end;
    msg_opl->sg.curr = orig_end;
    msg_opl->sg.copybreak = 0;
    msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
    msg_opl->sg.size += msg_npl->sg.size;

    sk_msg_free(sk, &to->msg_encrypted);
    sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);

    kfree(from);
}
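
/* Split/merge example (sketch): with a 16k open record and
 * msg_opl->apply_bytes == 4k, tls_split_open_record() carves the first
 * 4k off into the record being pushed and moves the remaining sg
 * entries into a new open record; tls_merge_open_record() undoes this
 * if the push fails, re-joining adjacent ranges of the same page.
 */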

static int tls_push_record(struct sock *sk, int flags,
               unsigned char record_type)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
    u32 i, split_point, orig_end;
    struct sk_msg *msg_pl, *msg_en;
    struct aead_request *req;
    bool split;
    int rc;

    if (!rec)
        return 0;

    msg_pl = &rec->msg_plaintext;
    msg_en = &rec->msg_encrypted;

    split_point = msg_pl->apply_bytes;
    split = split_point && split_point < msg_pl->sg.size;
    if (unlikely((!split &&
              msg_pl->sg.size +
              prot->overhead_size > msg_en->sg.size) ||
             (split &&
              split_point +
              prot->overhead_size > msg_en->sg.size))) {
        split = true;
        split_point = msg_en->sg.size;
    }
    if (split) {
        rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
                       split_point, prot->overhead_size,
                       &orig_end);
        if (rc < 0)
            return rc;
        /* This can happen if the tls_split_open_record() above
         * allocates a single large encryption buffer instead of two
         * smaller ones. In that case adjust the pointers and continue
         * without splitting.
         */
        if (!msg_pl->sg.size) {
            tls_merge_open_record(sk, rec, tmp, orig_end);
            msg_pl = &rec->msg_plaintext;
            msg_en = &rec->msg_encrypted;
            split = false;
        }
        sk_msg_trim(sk, msg_en, msg_pl->sg.size +
                prot->overhead_size);
    }

    rec->tx_flags = flags;
    req = &rec->aead_req;

    i = msg_pl->sg.end;
    sk_msg_iter_var_prev(i);

    rec->content_type = record_type;
    if (prot->version == TLS_1_3_VERSION) {
        /* Add the content type to the end of the message.
         * No padding is added.
         */
        sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
        sg_mark_end(&rec->sg_content_type);
        sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
             &rec->sg_content_type);
    } else {
        sg_mark_end(sk_msg_elem(msg_pl, i));
    }

    if (msg_pl->sg.end < msg_pl->sg.start) {
        sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
             MAX_SKB_FRAGS - msg_pl->sg.start + 1,
             msg_pl->sg.data);
    }

    i = msg_pl->sg.start;
    sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);

    i = msg_en->sg.end;
    sk_msg_iter_var_prev(i);
    sg_mark_end(sk_msg_elem(msg_en, i));

    i = msg_en->sg.start;
    sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);

    tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
             tls_ctx->tx.rec_seq, record_type, prot);

    tls_fill_prepend(tls_ctx,
             page_address(sg_page(&msg_en->sg.data[i])) +
             msg_en->sg.data[i].offset,
             msg_pl->sg.size + prot->tail_size,
             record_type);

    tls_ctx->pending_open_record_frags = false;

    rc = tls_do_encryption(sk, tls_ctx, ctx, req,
                   msg_pl->sg.size + prot->tail_size, i);
    if (rc < 0) {
        if (rc != -EINPROGRESS) {
            tls_err_abort(sk, -EBADMSG);
            if (split) {
                tls_ctx->pending_open_record_frags = true;
                tls_merge_open_record(sk, rec, tmp, orig_end);
            }
        }
        ctx->async_capable = 1;
        return rc;
    } else if (split) {
        msg_pl = &tmp->msg_plaintext;
        msg_en = &tmp->msg_encrypted;
        sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
        tls_ctx->pending_open_record_frags = true;
        ctx->open_rec = tmp;
    }

    return tls_tx_records(sk, flags);
}

static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
                   bool full_record, u8 record_type,
                   ssize_t *copied, int flags)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct sk_msg msg_redir = { };
    struct sk_psock *psock;
    struct sock *sk_redir;
    struct tls_rec *rec;
    bool enospc, policy;
    int err = 0, send;
    u32 delta = 0;

    policy = !(flags & MSG_SENDPAGE_NOPOLICY);
    psock = sk_psock_get(sk);
    if (!psock || !policy) {
        err = tls_push_record(sk, flags, record_type);
        if (err && sk->sk_err == EBADMSG) {
            *copied -= sk_msg_free(sk, msg);
            tls_free_open_rec(sk);
            err = -sk->sk_err;
        }
        if (psock)
            sk_psock_put(sk, psock);
        return err;
    }
more_data:
    enospc = sk_msg_full(msg);
    if (psock->eval == __SK_NONE) {
        delta = msg->sg.size;
        psock->eval = sk_psock_msg_verdict(sk, psock, msg);
        delta -= msg->sg.size;
    }
    if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
        !enospc && !full_record) {
        err = -ENOSPC;
        goto out_err;
    }
    msg->cork_bytes = 0;
    send = msg->sg.size;
    if (msg->apply_bytes && msg->apply_bytes < send)
        send = msg->apply_bytes;

    switch (psock->eval) {
    case __SK_PASS:
        err = tls_push_record(sk, flags, record_type);
        if (err && sk->sk_err == EBADMSG) {
            *copied -= sk_msg_free(sk, msg);
            tls_free_open_rec(sk);
            err = -sk->sk_err;
            goto out_err;
        }
        break;
    case __SK_REDIRECT:
        sk_redir = psock->sk_redir;
        memcpy(&msg_redir, msg, sizeof(*msg));
        if (msg->apply_bytes < send)
            msg->apply_bytes = 0;
        else
            msg->apply_bytes -= send;
        sk_msg_return_zero(sk, msg, send);
        msg->sg.size -= send;
        release_sock(sk);
        err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
        lock_sock(sk);
        if (err < 0) {
            *copied -= sk_msg_free_nocharge(sk, &msg_redir);
            msg->sg.size = 0;
        }
        if (msg->sg.size == 0)
            tls_free_open_rec(sk);
        break;
    case __SK_DROP:
    default:
        sk_msg_free_partial(sk, msg, send);
        if (msg->apply_bytes < send)
            msg->apply_bytes = 0;
        else
            msg->apply_bytes -= send;
        if (msg->sg.size == 0)
            tls_free_open_rec(sk);
        *copied -= (send + delta);
        err = -EACCES;
    }

    if (likely(!err)) {
        bool reset_eval = !ctx->open_rec;

        rec = ctx->open_rec;
        if (rec) {
            msg = &rec->msg_plaintext;
            if (!msg->apply_bytes)
                reset_eval = true;
        }
        if (reset_eval) {
            psock->eval = __SK_NONE;
            if (psock->sk_redir) {
                sock_put(psock->sk_redir);
                psock->sk_redir = NULL;
            }
        }
        if (rec)
            goto more_data;
    }
 out_err:
    sk_psock_put(sk, psock);
    return err;
}
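
/* Verdict handling above (sketch): __SK_PASS encrypts and transmits
 * the record on this socket, __SK_REDIRECT hands the plaintext sk_msg
 * to another socket via tcp_bpf_sendmsg_redir(), and __SK_DROP frees
 * the data and reports -EACCES to the sender.
 */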

static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_rec *rec = ctx->open_rec;
    struct sk_msg *msg_pl;
    size_t copied;

    if (!rec)
        return 0;

    msg_pl = &rec->msg_plaintext;
    copied = msg_pl->sg.size;
    if (!copied)
        return 0;

    return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
                   &copied, flags);
}

int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
    long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    bool async_capable = ctx->async_capable;
    unsigned char record_type = TLS_RECORD_TYPE_DATA;
    bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
    bool eor = !(msg->msg_flags & MSG_MORE);
    size_t try_to_copy;
    ssize_t copied = 0;
    struct sk_msg *msg_pl, *msg_en;
    struct tls_rec *rec;
    int required_size;
    int num_async = 0;
    bool full_record;
    int record_room;
    int num_zc = 0;
    int orig_size;
    int ret = 0;
    int pending;

    if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
                   MSG_CMSG_COMPAT))
        return -EOPNOTSUPP;

    mutex_lock(&tls_ctx->tx_lock);
    lock_sock(sk);

    if (unlikely(msg->msg_controllen)) {
        ret = tls_process_cmsg(sk, msg, &record_type);
        if (ret) {
            if (ret == -EINPROGRESS)
                num_async++;
            else if (ret != -EAGAIN)
                goto send_end;
        }
    }

    while (msg_data_left(msg)) {
        if (sk->sk_err) {
            ret = -sk->sk_err;
            goto send_end;
        }

        if (ctx->open_rec)
            rec = ctx->open_rec;
        else
            rec = ctx->open_rec = tls_get_rec(sk);
        if (!rec) {
            ret = -ENOMEM;
            goto send_end;
        }

        msg_pl = &rec->msg_plaintext;
        msg_en = &rec->msg_encrypted;

        orig_size = msg_pl->sg.size;
        full_record = false;
        try_to_copy = msg_data_left(msg);
        record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
        if (try_to_copy >= record_room) {
            try_to_copy = record_room;
            full_record = true;
        }

        required_size = msg_pl->sg.size + try_to_copy +
                prot->overhead_size;

        if (!sk_stream_memory_free(sk))
            goto wait_for_sndbuf;

alloc_encrypted:
        ret = tls_alloc_encrypted_msg(sk, required_size);
        if (ret) {
            if (ret != -ENOSPC)
                goto wait_for_memory;

            /* Adjust try_to_copy according to the amount that was
             * actually allocated. The difference is due to the
             * max sg elements limit.
             */
            try_to_copy -= required_size - msg_en->sg.size;
            full_record = true;
        }

        if (!is_kvec && (full_record || eor) && !async_capable) {
            u32 first = msg_pl->sg.end;

            ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
                            msg_pl, try_to_copy);
            if (ret)
                goto fallback_to_reg_send;

            num_zc++;
            copied += try_to_copy;

            sk_msg_sg_copy_set(msg_pl, first);
            ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
                          record_type, &copied,
                          msg->msg_flags);
            if (ret) {
                if (ret == -EINPROGRESS)
                    num_async++;
                else if (ret == -ENOMEM)
                    goto wait_for_memory;
                else if (ctx->open_rec && ret == -ENOSPC)
                    goto rollback_iter;
                else if (ret != -EAGAIN)
                    goto send_end;
            }
            continue;
rollback_iter:
            copied -= try_to_copy;
            sk_msg_sg_copy_clear(msg_pl, first);
            iov_iter_revert(&msg->msg_iter,
                    msg_pl->sg.size - orig_size);
fallback_to_reg_send:
            sk_msg_trim(sk, msg_pl, orig_size);
        }

        required_size = msg_pl->sg.size + try_to_copy;

        ret = tls_clone_plaintext_msg(sk, required_size);
        if (ret) {
            if (ret != -ENOSPC)
                goto send_end;

            /* Adjust try_to_copy according to the amount that was
             * actually allocated. The difference is due to the
             * max sg elements limit.
             */
            try_to_copy -= required_size - msg_pl->sg.size;
            full_record = true;
            sk_msg_trim(sk, msg_en,
                    msg_pl->sg.size + prot->overhead_size);
        }

        if (try_to_copy) {
            ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
                               msg_pl, try_to_copy);
            if (ret < 0)
                goto trim_sgl;
        }

        /* An open record is only flagged if the copy succeeded;
         * otherwise we would trim the sg but not reset the open
         * record frags.
         */
        tls_ctx->pending_open_record_frags = true;
        copied += try_to_copy;
        if (full_record || eor) {
            ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
                          record_type, &copied,
                          msg->msg_flags);
            if (ret) {
                if (ret == -EINPROGRESS)
                    num_async++;
                else if (ret == -ENOMEM)
                    goto wait_for_memory;
                else if (ret != -EAGAIN) {
                    if (ret == -ENOSPC)
                        ret = 0;
                    goto send_end;
                }
            }
        }

        continue;

wait_for_sndbuf:
        set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
        ret = sk_stream_wait_memory(sk, &timeo);
        if (ret) {
trim_sgl:
            if (ctx->open_rec)
                tls_trim_both_msgs(sk, orig_size);
            goto send_end;
        }

        if (ctx->open_rec && msg_en->sg.size < required_size)
            goto alloc_encrypted;
    }

    if (!num_async) {
        goto send_end;
    } else if (num_zc) {
        /* Wait for pending encryption operations to complete */
        spin_lock_bh(&ctx->encrypt_compl_lock);
        ctx->async_notify = true;

        pending = atomic_read(&ctx->encrypt_pending);
        spin_unlock_bh(&ctx->encrypt_compl_lock);
        if (pending)
            crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
        else
            reinit_completion(&ctx->async_wait.completion);

        /* There can be no concurrent accesses, since we have no
         * pending encrypt operations
         */
        WRITE_ONCE(ctx->async_notify, false);

        if (ctx->async_wait.err) {
            ret = ctx->async_wait.err;
            copied = 0;
        }
    }

    /* Transmit if any encryptions have completed */
    if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
        cancel_delayed_work(&ctx->tx_work.work);
        tls_tx_records(sk, msg->msg_flags);
    }

send_end:
    ret = sk_stream_error(sk, msg->msg_flags, ret);

    release_sock(sk);
    mutex_unlock(&tls_ctx->tx_lock);
    return copied > 0 ? copied : ret;
}

static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
                  int offset, size_t size, int flags)
{
    long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    unsigned char record_type = TLS_RECORD_TYPE_DATA;
    struct sk_msg *msg_pl;
    struct tls_rec *rec;
    int num_async = 0;
    ssize_t copied = 0;
    bool full_record;
    int record_room;
    int ret = 0;
    bool eor;

    eor = !(flags & MSG_SENDPAGE_NOTLAST);
    sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

    /* Call the sk_stream functions to manage the sndbuf memory. */
    while (size > 0) {
        size_t copy, required_size;

        if (sk->sk_err) {
            ret = -sk->sk_err;
            goto sendpage_end;
        }

        if (ctx->open_rec)
            rec = ctx->open_rec;
        else
            rec = ctx->open_rec = tls_get_rec(sk);
        if (!rec) {
            ret = -ENOMEM;
            goto sendpage_end;
        }

        msg_pl = &rec->msg_plaintext;

        full_record = false;
        record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
        copy = size;
        if (copy >= record_room) {
            copy = record_room;
            full_record = true;
        }

        required_size = msg_pl->sg.size + copy + prot->overhead_size;

        if (!sk_stream_memory_free(sk))
            goto wait_for_sndbuf;
alloc_payload:
        ret = tls_alloc_encrypted_msg(sk, required_size);
        if (ret) {
            if (ret != -ENOSPC)
                goto wait_for_memory;

            /* Adjust copy according to the amount that was
             * actually allocated. The difference is due to the
             * max sg elements limit.
             */
            copy -= required_size - msg_pl->sg.size;
            full_record = true;
        }

        sk_msg_page_add(msg_pl, page, copy, offset);
        sk_mem_charge(sk, copy);

        offset += copy;
        size -= copy;
        copied += copy;

        tls_ctx->pending_open_record_frags = true;
        if (full_record || eor || sk_msg_full(msg_pl)) {
            ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
                          record_type, &copied, flags);
            if (ret) {
                if (ret == -EINPROGRESS)
                    num_async++;
                else if (ret == -ENOMEM)
                    goto wait_for_memory;
                else if (ret != -EAGAIN) {
                    if (ret == -ENOSPC)
                        ret = 0;
                    goto sendpage_end;
                }
            }
        }
        continue;
wait_for_sndbuf:
        set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
        ret = sk_stream_wait_memory(sk, &timeo);
        if (ret) {
            if (ctx->open_rec)
                tls_trim_both_msgs(sk, msg_pl->sg.size);
            goto sendpage_end;
        }

        if (ctx->open_rec)
            goto alloc_payload;
    }

    if (num_async) {
        /* Transmit if any encryptions have completed */
        if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
            cancel_delayed_work(&ctx->tx_work.work);
            tls_tx_records(sk, flags);
        }
    }
sendpage_end:
    ret = sk_stream_error(sk, flags, ret);
    return copied > 0 ? copied : ret;
}

int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
               int offset, size_t size, int flags)
{
    if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
              MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
              MSG_NO_SHARED_FRAGS))
        return -EOPNOTSUPP;

    return tls_sw_do_sendpage(sk, page, offset, size, flags);
}

int tls_sw_sendpage(struct sock *sk, struct page *page,
            int offset, size_t size, int flags)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    int ret;

    if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
              MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
        return -EOPNOTSUPP;

    mutex_lock(&tls_ctx->tx_lock);
    lock_sock(sk);
    ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
    release_sock(sk);
    mutex_unlock(&tls_ctx->tx_lock);
    return ret;
}

static int
tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
        bool released)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
    DEFINE_WAIT_FUNC(wait, woken_wake_function);
    long timeo;

    timeo = sock_rcvtimeo(sk, nonblock);

    while (!tls_strp_msg_ready(ctx)) {
        if (!sk_psock_queue_empty(psock))
            return 0;

        if (sk->sk_err)
            return sock_error(sk);

        if (!skb_queue_empty(&sk->sk_receive_queue)) {
            tls_strp_check_rcv(&ctx->strp);
            if (tls_strp_msg_ready(ctx))
                break;
        }

        if (sk->sk_shutdown & RCV_SHUTDOWN)
            return 0;

        if (sock_flag(sk, SOCK_DONE))
            return 0;

        if (!timeo)
            return -EAGAIN;

        released = true;
        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        sk_wait_event(sk, &timeo,
                  tls_strp_msg_ready(ctx) ||
                  !sk_psock_queue_empty(psock),
                  &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);

        /* Handle signals */
        if (signal_pending(current))
            return sock_intr_errno(timeo);
    }

    tls_strp_msg_load(&ctx->strp, released);

    return 1;
}
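
/* Return values above (sketch): 1 when a full record is ready in the
 * strparser, 0 when the caller should look at the psock queue or the
 * socket is shut down / done, -EAGAIN on a zero timeout, and a
 * signal or socket error otherwise.
 */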

static int tls_setup_from_iter(struct iov_iter *from,
                   int length, int *pages_used,
                   struct scatterlist *to,
                   int to_max_pages)
{
    int rc = 0, i = 0, num_elem = *pages_used, maxpages;
    struct page *pages[MAX_SKB_FRAGS];
    unsigned int size = 0;
    ssize_t copied, use;
    size_t offset;

    while (length > 0) {
        i = 0;
        maxpages = to_max_pages - num_elem;
        if (maxpages == 0) {
            rc = -EFAULT;
            goto out;
        }
        copied = iov_iter_get_pages2(from, pages,
                        length,
                        maxpages, &offset);
        if (copied <= 0) {
            rc = -EFAULT;
            goto out;
        }

        length -= copied;
        size += copied;
        while (copied) {
            use = min_t(int, copied, PAGE_SIZE - offset);

            sg_set_page(&to[num_elem],
                    pages[i], use, offset);
            sg_unmark_end(&to[num_elem]);
            /* We do not uncharge memory from this API */

            offset = 0;
            copied -= use;

            i++;
            num_elem++;
        }
    }
    /* Mark the end in the last sg entry if newly added */
    if (num_elem > *pages_used)
        sg_mark_end(&to[num_elem - 1]);
out:
    if (rc)
        iov_iter_revert(from, size);
    *pages_used = num_elem;

    return rc;
}
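
/* Worked example (sketch, assuming 4k pages): mapping 6000 bytes that
 * start 3000 bytes into the first page yields sg entries of 1096, 4096
 * and 808 bytes, one per touched page, since use is capped at
 * PAGE_SIZE - offset for the first page of each batch.
 */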

static struct sk_buff *
tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
             unsigned int full_len)
{
    struct strp_msg *clr_rxm;
    struct sk_buff *clr_skb;
    int err;

    clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
                       &err, sk->sk_allocation);
    if (!clr_skb)
        return NULL;

    skb_copy_header(clr_skb, skb);
    clr_skb->len = full_len;
    clr_skb->data_len = full_len;

    clr_rxm = strp_msg(clr_skb);
    clr_rxm->offset = 0;

    return clr_skb;
}

/* Decrypt handlers
 *
 * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
 * They must transform the darg in/out argument as follows:
 *       |          Input            |         Output
 * -------------------------------------------------------------------
 *    zc | Zero-copy decrypt allowed | Zero-copy performed
 * async | Async decrypt allowed     | Async crypto used / in progress
 *   skb |            *              | Output skb
 *
 * If ZC decryption was performed darg.skb will point to the input skb.
 */
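
/* Example darg flow (sketch): a recvmsg() caller may pass in
 * { .zc = true, .async = true }; on return .zc tells it whether the
 * data already landed in the user iov, and .async whether the crypto
 * completion must still be awaited before the record can be consumed.
 */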

/* This function decrypts the input skb into either out_iov, out_sg, or
 * the skb's own buffers. The input parameter 'darg->zc' indicates whether
 * zero-copy mode should be tried; in zero-copy mode either out_iov or
 * out_sg must be non-NULL. If both out_iov and out_sg are NULL, the
 * decryption happens inside the skb buffers themselves, i.e. zero-copy
 * gets disabled and 'darg->zc' is updated.
 */
static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
              struct scatterlist *out_sg,
              struct tls_decrypt_arg *darg)
{
    struct tls_context *tls_ctx = tls_get_ctx(sk);
    struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    int n_sgin, n_sgout, aead_size, err, pages = 0;
    struct sk_buff *skb = tls_strp_msg(ctx);
    const struct strp_msg *rxm = strp_msg(skb);
    const struct tls_msg *tlm = tls_msg(skb);
    struct aead_request *aead_req;
    struct scatterlist *sgin = NULL;
    struct scatterlist *sgout = NULL;
    const int data_len = rxm->full_len - prot->overhead_size;
    int tail_pages = !!prot->tail_size;
    struct tls_decrypt_ctx *dctx;
    struct sk_buff *clear_skb;
    int iv_offset = 0;
    u8 *mem;

    n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
             rxm->full_len - prot->prepend_size);
    if (n_sgin < 1)
        return n_sgin ?: -EBADMSG;

    if (darg->zc && (out_iov || out_sg)) {
        clear_skb = NULL;

        if (out_iov)
            n_sgout = 1 + tail_pages +
                iov_iter_npages_cap(out_iov, INT_MAX, data_len);
        else
            n_sgout = sg_nents(out_sg);
    } else {
        darg->zc = false;

        clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
        if (!clear_skb)
            return -ENOMEM;

        n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
    }

    /* Increment to accommodate AAD */
    n_sgin = n_sgin + 1;

    /* Allocate a single block of memory which contains
     *   aead_req || tls_decrypt_ctx.
     * Both structs are variable length.
     */
    aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
    mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
              sk->sk_allocation);
    if (!mem) {
        err = -ENOMEM;
        goto exit_free_skb;
    }

    /* Segment the allocated memory */
    aead_req = (struct aead_request *)mem;
    dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
    sgin = &dctx->sg[0];
    sgout = &dctx->sg[n_sgin];

    /* For CCM-based ciphers, the first byte of the nonce+IV is a constant */
    switch (prot->cipher_type) {
    case TLS_CIPHER_AES_CCM_128:
        dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
        iv_offset = 1;
        break;
    case TLS_CIPHER_SM4_CCM:
        dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
        iv_offset = 1;
        break;
    }

    /* Prepare IV */
    if (prot->version == TLS_1_3_VERSION ||
        prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
        memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
               prot->iv_size + prot->salt_size);
    } else {
        err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
                    &dctx->iv[iv_offset] + prot->salt_size,
                    prot->iv_size);
        if (err < 0)
            goto exit_free;
        memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
    }
    tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);

    /* Prepare AAD */
    tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
             prot->tail_size,
             tls_ctx->rx.rec_seq, tlm->control, prot);

    /* Prepare sgin */
    sg_init_table(sgin, n_sgin);
    sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
    err = skb_to_sgvec(skb, &sgin[1],
               rxm->offset + prot->prepend_size,
               rxm->full_len - prot->prepend_size);
    if (err < 0)
        goto exit_free;

    if (clear_skb) {
        sg_init_table(sgout, n_sgout);
        sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);

        err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
                   data_len + prot->tail_size);
        if (err < 0)
            goto exit_free;
    } else if (out_iov) {
        sg_init_table(sgout, n_sgout);
        sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);

        err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
                      (n_sgout - 1 - tail_pages));
        if (err < 0)
            goto exit_free_pages;

        if (prot->tail_size) {
            sg_unmark_end(&sgout[pages]);
            sg_set_buf(&sgout[pages + 1], &dctx->tail,
                   prot->tail_size);
            sg_mark_end(&sgout[pages + 1]);
        }
    } else if (out_sg) {
        memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
    }

    /* Prepare and submit AEAD request */
    err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
                data_len + prot->tail_size, aead_req, darg);
    if (err)
        goto exit_free_pages;

    darg->skb = clear_skb ?: tls_strp_msg(ctx);
    clear_skb = NULL;

    if (unlikely(darg->async)) {
        err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
        if (err)
            __skb_queue_tail(&ctx->async_hold, darg->skb);
        return err;
    }

    if (prot->tail_size)
        darg->tail = dctx->tail;

exit_free_pages:
    /* Release the pages if the iov was mapped to pages */
    for (; pages > 0; pages--)
        put_page(sg_page(&sgout[pages]));
exit_free:
    kfree(mem);
exit_free_skb:
    consume_skb(clear_skb);
    return err;
}

static int
tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
           struct msghdr *msg, struct tls_decrypt_arg *darg)
{
    struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
    struct tls_prot_info *prot = &tls_ctx->prot_info;
    struct strp_msg *rxm;
    int pad, err;

    err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
    if (err < 0) {
        if (err == -EBADMSG)
            TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
        return err;
    }
    /* Keep going even for ->async; the code below handles the
     * TLS 1.3 padding.
     */

    /* If opportunistic TLS 1.3 ZC failed, retry without ZC */
    if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
             darg->tail != TLS_RECORD_TYPE_DATA)) {
        darg->zc = false;
        if (!darg->tail)
            TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
        return tls_decrypt_sw(sk, tls_ctx, msg, darg);
    }

    pad = tls_padding_length(prot, darg->skb, darg);
    if (pad < 0) {
        if (darg->skb != tls_strp_msg(ctx))
            consume_skb(darg->skb);
        return pad;
    }

    rxm = strp_msg(darg->skb);
    rxm->full_len -= pad;

    return 0;
}
1636 
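/*
 * Editor's illustrative sketch (not the kernel's tls_padding_length()):
 * how TLS 1.3 padding is logically stripped. A decrypted record ends with
 * the real content type byte followed by zero padding (RFC 8446, section
 * 5.4); scanning back from the tail yields the inner type and the number
 * of trailing bytes to drop. darg->tail caches that type byte so the
 * opportunistic zero-copy path above can detect non-data records after
 * the fact.
 */
static int example_tls13_trailer_len(const u8 *text, int len, u8 *type)
{
    int i;

    for (i = len - 1; i >= 0; i--) {
        if (text[i]) {
            *type = text[i];
            return len - i; /* zero padding plus the type byte */
        }
    }
    return -EBADMSG;        /* all-zero record: no content type */
}
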
1637 static int
1638 tls_decrypt_device(struct sock *sk, struct msghdr *msg,
1639            struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
1640 {
1641     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1642     struct tls_prot_info *prot = &tls_ctx->prot_info;
1643     struct strp_msg *rxm;
1644     int pad, err;
1645 
1646     if (tls_ctx->rx_conf != TLS_HW)
1647         return 0;
1648 
1649     err = tls_device_decrypted(sk, tls_ctx);
1650     if (err <= 0)
1651         return err;
1652 
1653     pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
1654     if (pad < 0)
1655         return pad;
1656 
1657     darg->async = false;
1658     darg->skb = tls_strp_msg(ctx);
1659     /* ->zc downgrade check, in case TLS 1.3 gets here */
1660     darg->zc &= !(prot->version == TLS_1_3_VERSION &&
1661               tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
1662 
1663     rxm = strp_msg(darg->skb);
1664     rxm->full_len -= pad;
1665 
1666     if (!darg->zc) {
1667         /* Non-ZC case needs a real skb */
1668         darg->skb = tls_strp_msg_detach(ctx);
1669         if (!darg->skb)
1670             return -ENOMEM;
1671     } else {
1672         unsigned int off, len;
1673 
1674         /* In the ZC case nobody cares about the output skb.
1675          * Just copy the data here. Note the skb is not fully trimmed.
1676          */
1677         off = rxm->offset + prot->prepend_size;
1678         len = rxm->full_len - prot->overhead_size;
1679 
1680         err = skb_copy_datagram_msg(darg->skb, off, msg, len);
1681         if (err)
1682             return err;
1683     }
1684     return 1;
1685 }
1686 
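/*
 * Editor's note on the contract above, as consumed by tls_rx_one_record():
 * tls_decrypt_device() returns < 0 on error, 0 when the record was not
 * handled by the device (fall through to the software path), and 1 when
 * the device already decrypted, or copied out, the record.
 */
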
1687 static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
1688                  struct tls_decrypt_arg *darg)
1689 {
1690     struct tls_context *tls_ctx = tls_get_ctx(sk);
1691     struct tls_prot_info *prot = &tls_ctx->prot_info;
1692     struct strp_msg *rxm;
1693     int err;
1694 
1695     err = tls_decrypt_device(sk, msg, tls_ctx, darg);
1696     if (!err)
1697         err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
1698     if (err < 0)
1699         return err;
1700 
1701     rxm = strp_msg(darg->skb);
1702     rxm->offset += prot->prepend_size;
1703     rxm->full_len -= prot->overhead_size;
1704     tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1705 
1706     return 0;
1707 }
1708 
1709 int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1710 {
1711     struct tls_decrypt_arg darg = { .zc = true, };
1712 
1713     return tls_decrypt_sg(sk, NULL, sgout, &darg);
1714 }
1715 
1716 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1717                    u8 *control)
1718 {
1719     int err;
1720 
1721     if (!*control) {
1722         *control = tlm->control;
1723         if (!*control)
1724             return -EBADMSG;
1725 
1726         err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1727                    sizeof(*control), control);
1728         if (*control != TLS_RECORD_TYPE_DATA) {
1729             if (err || msg->msg_flags & MSG_CTRUNC)
1730                 return -EIO;
1731         }
1732     } else if (*control != tlm->control) {
1733         return 0;
1734     }
1735 
1736     return 1;
1737 }
1738 
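/*
 * Editor's usage sketch (hypothetical userspace code, not part of this
 * file): receiving the record type that tls_record_content_type() emits
 * via the TLS_GET_RECORD_TYPE cmsg. "fd" is an assumed established kTLS
 * socket with TLS_RX configured.
 */
#if 0 /* userspace */
#include <linux/tls.h>
#include <sys/socket.h>

char buf[16384];
union {
    struct cmsghdr hdr;
    char space[CMSG_SPACE(sizeof(unsigned char))];
} cbuf;
struct iovec iov = { .iov_base = buf, .iov_len = sizeof(buf) };
struct msghdr msg = {
    .msg_iov = &iov, .msg_iovlen = 1,
    .msg_control = &cbuf, .msg_controllen = sizeof(cbuf),
};
ssize_t n = recvmsg(fd, &msg, 0);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);

if (n > 0 && cmsg && cmsg->cmsg_level == SOL_TLS &&
    cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
    unsigned char record_type = *CMSG_DATA(cmsg);
    /* 23 is application data; anything else (e.g. an alert) arrives
     * as a separate, complete control record.
     */
}
#endif
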
1739 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1740 {
1741     tls_strp_msg_done(&ctx->strp);
1742 }
1743 
1744 /* This function traverses the rx_list in the tls receive context and copies
1745  * the decrypted records into the buffer provided by the caller when zero copy
1746  * is not true. Further, records are removed from the rx_list if it is not a
1747  * peek case and the record has been consumed completely.
1748  */
1749 static int process_rx_list(struct tls_sw_context_rx *ctx,
1750                struct msghdr *msg,
1751                u8 *control,
1752                size_t skip,
1753                size_t len,
1754                bool is_peek)
1755 {
1756     struct sk_buff *skb = skb_peek(&ctx->rx_list);
1757     struct tls_msg *tlm;
1758     ssize_t copied = 0;
1759     int err;
1760 
1761     while (skip && skb) {
1762         struct strp_msg *rxm = strp_msg(skb);
1763         tlm = tls_msg(skb);
1764 
1765         err = tls_record_content_type(msg, tlm, control);
1766         if (err <= 0)
1767             goto out;
1768 
1769         if (skip < rxm->full_len)
1770             break;
1771 
1772         skip = skip - rxm->full_len;
1773         skb = skb_peek_next(skb, &ctx->rx_list);
1774     }
1775 
1776     while (len && skb) {
1777         struct sk_buff *next_skb;
1778         struct strp_msg *rxm = strp_msg(skb);
1779         int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1780 
1781         tlm = tls_msg(skb);
1782 
1783         err = tls_record_content_type(msg, tlm, control);
1784         if (err <= 0)
1785             goto out;
1786 
1787         err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1788                         msg, chunk);
1789         if (err < 0)
1790             goto out;
1791 
1792         len = len - chunk;
1793         copied = copied + chunk;
1794 
1795         /* Consume the data from the record in the non-peek case */
1796         if (!is_peek) {
1797             rxm->offset = rxm->offset + chunk;
1798             rxm->full_len = rxm->full_len - chunk;
1799 
1800             /* Return if there is unconsumed data in the record */
1801             if (rxm->full_len - skip)
1802                 break;
1803         }
1804 
1805         /* Any remaining skip-bytes must lie in the first record in the
1806          * rx_list, so from the second record onwards 'skip' is 0.
1807          */
1808         skip = 0;
1809 
1810         if (msg)
1811             msg->msg_flags |= MSG_EOR;
1812 
1813         next_skb = skb_peek_next(skb, &ctx->rx_list);
1814 
1815         if (!is_peek) {
1816             __skb_unlink(skb, &ctx->rx_list);
1817             consume_skb(skb);
1818         }
1819 
1820         skb = next_skb;
1821     }
1822     err = 0;
1823 
1824 out:
1825     return copied ? : err;
1826 }
1827 
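/*
 * Editor's usage sketch (hypothetical userspace code): the is_peek path
 * above leaves records queued on the rx_list, so the same bytes can be
 * read twice. "fd" is an assumed established kTLS socket.
 */
#if 0 /* userspace */
char tmp[256];
ssize_t peeked = recv(fd, tmp, sizeof(tmp), MSG_PEEK); /* records stay queued */
ssize_t taken  = recv(fd, tmp, sizeof(tmp), 0);        /* same bytes, consumed */
#endif
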
1828 static bool
1829 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1830                size_t len_left, size_t decrypted, ssize_t done,
1831                size_t *flushed_at)
1832 {
1833     size_t max_rec;
1834 
1835     if (len_left <= decrypted)
1836         return false;
1837 
1838     max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1839     if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1840         return false;
1841 
1842     *flushed_at = done;
1843     return sk_flush_backlog(sk);
1844 }
1845 
1846 static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1847                   bool nonblock)
1848 {
1849     long timeo;
1850     int err;
1851 
1852     lock_sock(sk);
1853 
1854     timeo = sock_rcvtimeo(sk, nonblock);
1855 
1856     while (unlikely(ctx->reader_present)) {
1857         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1858 
1859         ctx->reader_contended = 1;
1860 
1861         add_wait_queue(&ctx->wq, &wait);
1862         sk_wait_event(sk, &timeo,
1863                   !READ_ONCE(ctx->reader_present), &wait);
1864         remove_wait_queue(&ctx->wq, &wait);
1865 
1866         if (timeo <= 0) {
1867             err = -EAGAIN;
1868             goto err_unlock;
1869         }
1870         if (signal_pending(current)) {
1871             err = sock_intr_errno(timeo);
1872             goto err_unlock;
1873         }
1874     }
1875 
1876     WRITE_ONCE(ctx->reader_present, 1);
1877 
1878     return 0;
1879 
1880 err_unlock:
1881     release_sock(sk);
1882     return err;
1883 }
1884 
1885 static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
1886 {
1887     if (unlikely(ctx->reader_contended)) {
1888         if (wq_has_sleeper(&ctx->wq))
1889             wake_up(&ctx->wq);
1890         else
1891             ctx->reader_contended = 0;
1892 
1893         WARN_ON_ONCE(!ctx->reader_present);
1894     }
1895 
1896     WRITE_ONCE(ctx->reader_present, 0);
1897     release_sock(sk);
1898 }
1899 
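/*
 * Editor's sketch (hypothetical caller) of the single-reader discipline
 * implemented above: every successful tls_rx_reader_lock() must be paired
 * with tls_rx_reader_unlock(), which only pays the wake_up() cost when
 * another reader actually contended.
 */
static int example_locked_reader(struct sock *sk, struct tls_sw_context_rx *ctx)
{
    int err;

    err = tls_rx_reader_lock(sk, ctx, false /* nonblock */);
    if (err < 0)
        return err;
    /* ... consume records while holding the socket lock ... */
    tls_rx_reader_unlock(sk, ctx);
    return 0;
}
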
1900 int tls_sw_recvmsg(struct sock *sk,
1901            struct msghdr *msg,
1902            size_t len,
1903            int flags,
1904            int *addr_len)
1905 {
1906     struct tls_context *tls_ctx = tls_get_ctx(sk);
1907     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1908     struct tls_prot_info *prot = &tls_ctx->prot_info;
1909     ssize_t decrypted = 0, async_copy_bytes = 0;
1910     struct sk_psock *psock;
1911     unsigned char control = 0;
1912     size_t flushed_at = 0;
1913     struct strp_msg *rxm;
1914     struct tls_msg *tlm;
1915     ssize_t copied = 0;
1916     bool async = false;
1917     int target, err;
1918     bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1919     bool is_peek = flags & MSG_PEEK;
1920     bool released = true;
1921     bool bpf_strp_enabled;
1922     bool zc_capable;
1923 
1924     if (unlikely(flags & MSG_ERRQUEUE))
1925         return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1926 
1927     err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1928     if (err < 0)
1929         return err;
1930     psock = sk_psock_get(sk);
1931     bpf_strp_enabled = sk_psock_strp_enabled(psock);
1932 
1933     /* If crypto failed the connection is broken */
1934     err = ctx->async_wait.err;
1935     if (err)
1936         goto end;
1937 
1938     /* Process pending decrypted records. They must be non-zero-copy */
1939     err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
1940     if (err < 0)
1941         goto end;
1942 
1943     copied = err;
1944     if (len <= copied)
1945         goto end;
1946 
1947     target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1948     len = len - copied;
1949 
1950     zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1951         ctx->zc_capable;
1952     decrypted = 0;
1953     while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
1954         struct tls_decrypt_arg darg;
1955         int to_decrypt, chunk;
1956 
1957         err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
1958                       released);
1959         if (err <= 0) {
1960             if (psock) {
1961                 chunk = sk_msg_recvmsg(sk, psock, msg, len,
1962                                flags);
1963                 if (chunk > 0) {
1964                     decrypted += chunk;
1965                     len -= chunk;
1966                     continue;
1967                 }
1968             }
1969             goto recv_end;
1970         }
1971 
1972         memset(&darg.inargs, 0, sizeof(darg.inargs));
1973 
1974         rxm = strp_msg(tls_strp_msg(ctx));
1975         tlm = tls_msg(tls_strp_msg(ctx));
1976 
1977         to_decrypt = rxm->full_len - prot->overhead_size;
1978 
1979         if (zc_capable && to_decrypt <= len &&
1980             tlm->control == TLS_RECORD_TYPE_DATA)
1981             darg.zc = true;
1982 
1983         /* Do not use async mode if record is non-data */
1984         if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1985             darg.async = ctx->async_capable;
1986         else
1987             darg.async = false;
1988 
1989         err = tls_rx_one_record(sk, msg, &darg);
1990         if (err < 0) {
1991             tls_err_abort(sk, -EBADMSG);
1992             goto recv_end;
1993         }
1994 
1995         async |= darg.async;
1996 
1997         /* If the type of records being processed is not known yet,
1998          * set it to the record type just dequeued. If it is already
1999          * known but does not match the record type just dequeued, go to
2000          * end. We always get the record type here since for TLS 1.2 the
2001          * record type is known as soon as the record is dequeued from
2002          * the stream parser. For TLS 1.3, we disable async.
2003          */
2004         err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
2005         if (err <= 0) {
2006             DEBUG_NET_WARN_ON_ONCE(darg.zc);
2007             tls_rx_rec_done(ctx);
2008 put_on_rx_list_err:
2009             __skb_queue_tail(&ctx->rx_list, darg.skb);
2010             goto recv_end;
2011         }
2012 
2013         /* periodically flush backlog, and feed strparser */
2014         released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
2015                           decrypted + copied,
2016                           &flushed_at);
2017 
2018         /* TLS 1.3 may have updated the length by more than overhead */
2019         rxm = strp_msg(darg.skb);
2020         chunk = rxm->full_len;
2021         tls_rx_rec_done(ctx);
2022 
2023         if (!darg.zc) {
2024             bool partially_consumed = chunk > len;
2025             struct sk_buff *skb = darg.skb;
2026 
2027             DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
2028 
2029             if (async) {
2030                 /* TLS 1.2-only, to_decrypt must be text len */
2031                 chunk = min_t(int, to_decrypt, len);
2032                 async_copy_bytes += chunk;
2033 put_on_rx_list:
2034                 decrypted += chunk;
2035                 len -= chunk;
2036                 __skb_queue_tail(&ctx->rx_list, skb);
2037                 continue;
2038             }
2039 
2040             if (bpf_strp_enabled) {
2041                 released = true;
2042                 err = sk_psock_tls_strp_read(psock, skb);
2043                 if (err != __SK_PASS) {
2044                     rxm->offset = rxm->offset + rxm->full_len;
2045                     rxm->full_len = 0;
2046                     if (err == __SK_DROP)
2047                         consume_skb(skb);
2048                     continue;
2049                 }
2050             }
2051 
2052             if (partially_consumed)
2053                 chunk = len;
2054 
2055             err = skb_copy_datagram_msg(skb, rxm->offset,
2056                             msg, chunk);
2057             if (err < 0)
2058                 goto put_on_rx_list_err;
2059 
2060             if (is_peek)
2061                 goto put_on_rx_list;
2062 
2063             if (partially_consumed) {
2064                 rxm->offset += chunk;
2065                 rxm->full_len -= chunk;
2066                 goto put_on_rx_list;
2067             }
2068 
2069             consume_skb(skb);
2070         }
2071 
2072         decrypted += chunk;
2073         len -= chunk;
2074 
2075         /* Return full control message to userspace before trying
2076          * to parse another message type
2077          */
2078         msg->msg_flags |= MSG_EOR;
2079         if (control != TLS_RECORD_TYPE_DATA)
2080             break;
2081     }
2082 
2083 recv_end:
2084     if (async) {
2085         int ret, pending;
2086 
2087         /* Wait for all previously submitted records to be decrypted */
2088         spin_lock_bh(&ctx->decrypt_compl_lock);
2089         reinit_completion(&ctx->async_wait.completion);
2090         pending = atomic_read(&ctx->decrypt_pending);
2091         spin_unlock_bh(&ctx->decrypt_compl_lock);
2092         ret = 0;
2093         if (pending)
2094             ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2095         __skb_queue_purge(&ctx->async_hold);
2096 
2097         if (ret) {
2098             if (err >= 0 || err == -EINPROGRESS)
2099                 err = ret;
2100             decrypted = 0;
2101             goto end;
2102         }
2103 
2104         /* Drain records from the rx_list & copy if required */
2105         if (is_peek || is_kvec)
2106             err = process_rx_list(ctx, msg, &control, copied,
2107                           decrypted, is_peek);
2108         else
2109             err = process_rx_list(ctx, msg, &control, 0,
2110                           async_copy_bytes, is_peek);
2111         decrypted = max(err, 0);
2112     }
2113 
2114     copied += decrypted;
2115 
2116 end:
2117     tls_rx_reader_unlock(sk, ctx);
2118     if (psock)
2119         sk_psock_put(sk, psock);
2120     return copied ? : err;
2121 }
2122 
2123 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
2124                struct pipe_inode_info *pipe,
2125                size_t len, unsigned int flags)
2126 {
2127     struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2128     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2129     struct strp_msg *rxm = NULL;
2130     struct sock *sk = sock->sk;
2131     struct tls_msg *tlm;
2132     struct sk_buff *skb;
2133     ssize_t copied = 0;
2134     int chunk;
2135     int err;
2136 
2137     err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2138     if (err < 0)
2139         return err;
2140 
2141     if (!skb_queue_empty(&ctx->rx_list)) {
2142         skb = __skb_dequeue(&ctx->rx_list);
2143     } else {
2144         struct tls_decrypt_arg darg;
2145 
2146         err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2147                       true);
2148         if (err <= 0)
2149             goto splice_read_end;
2150 
2151         memset(&darg.inargs, 0, sizeof(darg.inargs));
2152 
2153         err = tls_rx_one_record(sk, NULL, &darg);
2154         if (err < 0) {
2155             tls_err_abort(sk, -EBADMSG);
2156             goto splice_read_end;
2157         }
2158 
2159         tls_rx_rec_done(ctx);
2160         skb = darg.skb;
2161     }
2162 
2163     rxm = strp_msg(skb);
2164     tlm = tls_msg(skb);
2165 
2166     /* splice does not support reading control messages */
2167     if (tlm->control != TLS_RECORD_TYPE_DATA) {
2168         err = -EINVAL;
2169         goto splice_requeue;
2170     }
2171 
2172     chunk = min_t(unsigned int, rxm->full_len, len);
2173     copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2174     if (copied < 0)
2175         goto splice_requeue;
2176 
2177     if (chunk < rxm->full_len) {
2178         rxm->offset += len;
2179         rxm->full_len -= len;
2180         goto splice_requeue;
2181     }
2182 
2183     consume_skb(skb);
2184 
2185 splice_read_end:
2186     tls_rx_reader_unlock(sk, ctx);
2187     return copied ? : err;
2188 
2189 splice_requeue:
2190     __skb_queue_head(&ctx->rx_list, skb);
2191     goto splice_read_end;
2192 }
2193 
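/*
 * Editor's usage sketch (hypothetical userspace code, needs _GNU_SOURCE
 * and <fcntl.h>): moving decrypted application data straight into a pipe.
 * Only data records can be spliced; a pending control record fails the
 * splice() with EINVAL. "tls_fd" and "pipefd" are assumed set up already.
 */
#if 0 /* userspace */
ssize_t n = splice(tls_fd, NULL, pipefd[1], NULL, 16384, SPLICE_F_NONBLOCK);
#endif
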
2194 bool tls_sw_sock_is_readable(struct sock *sk)
2195 {
2196     struct tls_context *tls_ctx = tls_get_ctx(sk);
2197     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2198     bool ingress_empty = true;
2199     struct sk_psock *psock;
2200 
2201     rcu_read_lock();
2202     psock = sk_psock(sk);
2203     if (psock)
2204         ingress_empty = list_empty(&psock->ingress_msg);
2205     rcu_read_unlock();
2206 
2207     return !ingress_empty || tls_strp_msg_ready(ctx) ||
2208         !skb_queue_empty(&ctx->rx_list);
2209 }
2210 
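/*
 * Editor's usage sketch (hypothetical userspace code): the readability
 * test above feeds the poll() path, so POLLIN fires when psock ingress
 * data, a fully parsed record, or a queued decrypted record is ready.
 */
#if 0 /* userspace */
#include <poll.h>

struct pollfd pfd = { .fd = tls_fd, .events = POLLIN };
int ready = poll(&pfd, 1, 1000 /* timeout, ms */);
#endif
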
2211 int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
2212 {
2213     struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2214     struct tls_prot_info *prot = &tls_ctx->prot_info;
2215     char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2216     size_t cipher_overhead;
2217     size_t data_len = 0;
2218     int ret;
2219 
2220     /* Verify that we have a full TLS header, or wait for more data */
2221     if (strp->stm.offset + prot->prepend_size > skb->len)
2222         return 0;
2223 
2224     /* Sanity-check size of on-stack buffer. */
2225     if (WARN_ON(prot->prepend_size > sizeof(header))) {
2226         ret = -EINVAL;
2227         goto read_failure;
2228     }
2229 
2230     /* Linearize header to local buffer */
2231     ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
2232     if (ret < 0)
2233         goto read_failure;
2234 
2235     strp->mark = header[0];
2236 
2237     data_len = ((header[4] & 0xFF) | (header[3] << 8));
2238 
2239     cipher_overhead = prot->tag_size;
2240     if (prot->version != TLS_1_3_VERSION &&
2241         prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2242         cipher_overhead += prot->iv_size;
2243 
2244     if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2245         prot->tail_size) {
2246         ret = -EMSGSIZE;
2247         goto read_failure;
2248     }
2249     if (data_len < cipher_overhead) {
2250         ret = -EBADMSG;
2251         goto read_failure;
2252     }
2253 
2254     /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2255     if (header[1] != TLS_1_2_VERSION_MINOR ||
2256         header[2] != TLS_1_2_VERSION_MAJOR) {
2257         ret = -EINVAL;
2258         goto read_failure;
2259     }
2260 
2261     tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2262                      TCP_SKB_CB(skb)->seq + strp->stm.offset);
2263     return data_len + TLS_HEADER_SIZE;
2264 
2265 read_failure:
2266     tls_err_abort(strp->sk, ret);
2267 
2268     return ret;
2269 }
2270 
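/*
 * Editor's illustrative sketch: the 5-byte record header parsed above
 * (RFC 8446, section 5.1). Protected records in both TLS 1.2 and TLS 1.3
 * carry 0x0303 in the version field, which is why only
 * TLS_1_2_VERSION_{MAJOR,MINOR} is accepted.
 */
struct example_tls_record_hdr {
    u8     type;          /* header[0], stashed in strp->mark */
    u8     version_major; /* header[1], 0x03 */
    u8     version_minor; /* header[2], 0x03 */
    __be16 length;        /* header[3..4], payload length, big endian */
} __packed;
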
2271 void tls_rx_msg_ready(struct tls_strparser *strp)
2272 {
2273     struct tls_sw_context_rx *ctx;
2274 
2275     ctx = container_of(strp, struct tls_sw_context_rx, strp);
2276     ctx->saved_data_ready(strp->sk);
2277 }
2278 
2279 static void tls_data_ready(struct sock *sk)
2280 {
2281     struct tls_context *tls_ctx = tls_get_ctx(sk);
2282     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2283     struct sk_psock *psock;
2284 
2285     tls_strp_data_ready(&ctx->strp);
2286 
2287     psock = sk_psock_get(sk);
2288     if (psock) {
2289         if (!list_empty(&psock->ingress_msg))
2290             ctx->saved_data_ready(sk);
2291         sk_psock_put(sk, psock);
2292     }
2293 }
2294 
2295 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2296 {
2297     struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2298 
2299     set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2300     set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2301     cancel_delayed_work_sync(&ctx->tx_work.work);
2302 }
2303 
2304 void tls_sw_release_resources_tx(struct sock *sk)
2305 {
2306     struct tls_context *tls_ctx = tls_get_ctx(sk);
2307     struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2308     struct tls_rec *rec, *tmp;
2309     int pending;
2310 
2311     /* Wait for any pending async encryptions to complete */
2312     spin_lock_bh(&ctx->encrypt_compl_lock);
2313     ctx->async_notify = true;
2314     pending = atomic_read(&ctx->encrypt_pending);
2315     spin_unlock_bh(&ctx->encrypt_compl_lock);
2316 
2317     if (pending)
2318         crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2319 
2320     tls_tx_records(sk, -1);
2321 
2322     /* Free up un-sent records in tx_list. First, free the partially
2323      * sent record, if any, at the head of tx_list.
2324      */
2325     if (tls_ctx->partially_sent_record) {
2326         tls_free_partial_record(sk, tls_ctx);
2327         rec = list_first_entry(&ctx->tx_list,
2328                        struct tls_rec, list);
2329         list_del(&rec->list);
2330         sk_msg_free(sk, &rec->msg_plaintext);
2331         kfree(rec);
2332     }
2333 
2334     list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2335         list_del(&rec->list);
2336         sk_msg_free(sk, &rec->msg_encrypted);
2337         sk_msg_free(sk, &rec->msg_plaintext);
2338         kfree(rec);
2339     }
2340 
2341     crypto_free_aead(ctx->aead_send);
2342     tls_free_open_rec(sk);
2343 }
2344 
2345 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2346 {
2347     struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2348 
2349     kfree(ctx);
2350 }
2351 
2352 void tls_sw_release_resources_rx(struct sock *sk)
2353 {
2354     struct tls_context *tls_ctx = tls_get_ctx(sk);
2355     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2356 
2357     kfree(tls_ctx->rx.rec_seq);
2358     kfree(tls_ctx->rx.iv);
2359 
2360     if (ctx->aead_recv) {
2361         __skb_queue_purge(&ctx->rx_list);
2362         crypto_free_aead(ctx->aead_recv);
2363         tls_strp_stop(&ctx->strp);
2364         /* If tls_sw_strparser_arm() was not called (cleanup paths)
2365          * we still want to tls_strp_stop(), but sk->sk_data_ready was
2366          * never swapped.
2367          */
2368         if (ctx->saved_data_ready) {
2369             write_lock_bh(&sk->sk_callback_lock);
2370             sk->sk_data_ready = ctx->saved_data_ready;
2371             write_unlock_bh(&sk->sk_callback_lock);
2372         }
2373     }
2374 }
2375 
2376 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2377 {
2378     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2379 
2380     tls_strp_done(&ctx->strp);
2381 }
2382 
2383 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2384 {
2385     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2386 
2387     kfree(ctx);
2388 }
2389 
2390 void tls_sw_free_resources_rx(struct sock *sk)
2391 {
2392     struct tls_context *tls_ctx = tls_get_ctx(sk);
2393 
2394     tls_sw_release_resources_rx(sk);
2395     tls_sw_free_ctx_rx(tls_ctx);
2396 }
2397 
2398 /* The work handler to transmit the encrypted records in tx_list */
2399 static void tx_work_handler(struct work_struct *work)
2400 {
2401     struct delayed_work *delayed_work = to_delayed_work(work);
2402     struct tx_work *tx_work = container_of(delayed_work,
2403                            struct tx_work, work);
2404     struct sock *sk = tx_work->sk;
2405     struct tls_context *tls_ctx = tls_get_ctx(sk);
2406     struct tls_sw_context_tx *ctx;
2407 
2408     if (unlikely(!tls_ctx))
2409         return;
2410 
2411     ctx = tls_sw_ctx_tx(tls_ctx);
2412     if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2413         return;
2414 
2415     if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2416         return;
2417     mutex_lock(&tls_ctx->tx_lock);
2418     lock_sock(sk);
2419     tls_tx_records(sk, -1);
2420     release_sock(sk);
2421     mutex_unlock(&tls_ctx->tx_lock);
2422 }
2423 
2424 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2425 {
2426     struct tls_rec *rec;
2427 
2428     rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list);
2429     if (!rec)
2430         return false;
2431 
2432     return READ_ONCE(rec->tx_ready);
2433 }
2434 
2435 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2436 {
2437     struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2438 
2439     /* Schedule the transmission if tx list is ready */
2440     if (tls_is_tx_ready(tx_ctx) &&
2441         !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2442         schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2443 }
2444 
2445 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2446 {
2447     struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2448 
2449     write_lock_bh(&sk->sk_callback_lock);
2450     rx_ctx->saved_data_ready = sk->sk_data_ready;
2451     sk->sk_data_ready = tls_data_ready;
2452     write_unlock_bh(&sk->sk_callback_lock);
2453 }
2454 
2455 void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2456 {
2457     struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2458 
2459     rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2460         tls_ctx->prot_info.version != TLS_1_3_VERSION;
2461 }
2462 
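/*
 * Editor's usage sketch (hypothetical userspace code): rx_no_pad above is
 * driven by the TLS_RX_EXPECT_NO_PAD socket option, which lets a TLS 1.3
 * application opt back into zero-copy by promising that the peer sends
 * unpadded data records.
 */
#if 0 /* userspace */
int one = 1;

setsockopt(tls_fd, SOL_TLS, TLS_RX_EXPECT_NO_PAD, &one, sizeof(one));
#endif
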
2463 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2464 {
2465     struct tls_context *tls_ctx = tls_get_ctx(sk);
2466     struct tls_prot_info *prot = &tls_ctx->prot_info;
2467     struct tls_crypto_info *crypto_info;
2468     struct tls_sw_context_tx *sw_ctx_tx = NULL;
2469     struct tls_sw_context_rx *sw_ctx_rx = NULL;
2470     struct cipher_context *cctx;
2471     struct crypto_aead **aead;
2472     u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2473     struct crypto_tfm *tfm;
2474     char *iv, *rec_seq, *key, *salt, *cipher_name;
2475     size_t keysize;
2476     int rc = 0;
2477 
2478     if (!ctx) {
2479         rc = -EINVAL;
2480         goto out;
2481     }
2482 
2483     if (tx) {
2484         if (!ctx->priv_ctx_tx) {
2485             sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2486             if (!sw_ctx_tx) {
2487                 rc = -ENOMEM;
2488                 goto out;
2489             }
2490             ctx->priv_ctx_tx = sw_ctx_tx;
2491         } else {
2492             sw_ctx_tx =
2493                 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2494         }
2495     } else {
2496         if (!ctx->priv_ctx_rx) {
2497             sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2498             if (!sw_ctx_rx) {
2499                 rc = -ENOMEM;
2500                 goto out;
2501             }
2502             ctx->priv_ctx_rx = sw_ctx_rx;
2503         } else {
2504             sw_ctx_rx =
2505                 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2506         }
2507     }
2508 
2509     if (tx) {
2510         crypto_init_wait(&sw_ctx_tx->async_wait);
2511         spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2512         crypto_info = &ctx->crypto_send.info;
2513         cctx = &ctx->tx;
2514         aead = &sw_ctx_tx->aead_send;
2515         INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2516         INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2517         sw_ctx_tx->tx_work.sk = sk;
2518     } else {
2519         crypto_init_wait(&sw_ctx_rx->async_wait);
2520         spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2521         init_waitqueue_head(&sw_ctx_rx->wq);
2522         crypto_info = &ctx->crypto_recv.info;
2523         cctx = &ctx->rx;
2524         skb_queue_head_init(&sw_ctx_rx->rx_list);
2525         skb_queue_head_init(&sw_ctx_rx->async_hold);
2526         aead = &sw_ctx_rx->aead_recv;
2527     }
2528 
2529     switch (crypto_info->cipher_type) {
2530     case TLS_CIPHER_AES_GCM_128: {
2531         struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2532 
2533         gcm_128_info = (void *)crypto_info;
2534         nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2535         tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2536         iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2537         iv = gcm_128_info->iv;
2538         rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2539         rec_seq = gcm_128_info->rec_seq;
2540         keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2541         key = gcm_128_info->key;
2542         salt = gcm_128_info->salt;
2543         salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2544         cipher_name = "gcm(aes)";
2545         break;
2546     }
2547     case TLS_CIPHER_AES_GCM_256: {
2548         struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2549 
2550         gcm_256_info = (void *)crypto_info;
2551         nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2552         tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2553         iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2554         iv = gcm_256_info->iv;
2555         rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2556         rec_seq = gcm_256_info->rec_seq;
2557         keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2558         key = gcm_256_info->key;
2559         salt = gcm_256_info->salt;
2560         salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2561         cipher_name = "gcm(aes)";
2562         break;
2563     }
2564     case TLS_CIPHER_AES_CCM_128: {
2565         struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2566 
2567         ccm_128_info = (void *)crypto_info;
2568         nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2569         tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2570         iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2571         iv = ccm_128_info->iv;
2572         rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2573         rec_seq = ccm_128_info->rec_seq;
2574         keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2575         key = ccm_128_info->key;
2576         salt = ccm_128_info->salt;
2577         salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2578         cipher_name = "ccm(aes)";
2579         break;
2580     }
2581     case TLS_CIPHER_CHACHA20_POLY1305: {
2582         struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
2583 
2584         chacha20_poly1305_info = (void *)crypto_info;
2585         nonce_size = 0;
2586         tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
2587         iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
2588         iv = chacha20_poly1305_info->iv;
2589         rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
2590         rec_seq = chacha20_poly1305_info->rec_seq;
2591         keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
2592         key = chacha20_poly1305_info->key;
2593         salt = chacha20_poly1305_info->salt;
2594         salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
2595         cipher_name = "rfc7539(chacha20,poly1305)";
2596         break;
2597     }
2598     case TLS_CIPHER_SM4_GCM: {
2599         struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
2600 
2601         sm4_gcm_info = (void *)crypto_info;
2602         nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2603         tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
2604         iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2605         iv = sm4_gcm_info->iv;
2606         rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
2607         rec_seq = sm4_gcm_info->rec_seq;
2608         keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
2609         key = sm4_gcm_info->key;
2610         salt = sm4_gcm_info->salt;
2611         salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
2612         cipher_name = "gcm(sm4)";
2613         break;
2614     }
2615     case TLS_CIPHER_SM4_CCM: {
2616         struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
2617 
2618         sm4_ccm_info = (void *)crypto_info;
2619         nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2620         tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
2621         iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2622         iv = sm4_ccm_info->iv;
2623         rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
2624         rec_seq = sm4_ccm_info->rec_seq;
2625         keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
2626         key = sm4_ccm_info->key;
2627         salt = sm4_ccm_info->salt;
2628         salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
2629         cipher_name = "ccm(sm4)";
2630         break;
2631     }
2632     default:
2633         rc = -EINVAL;
2634         goto free_priv;
2635     }
2636 
2637     if (crypto_info->version == TLS_1_3_VERSION) {
2638         nonce_size = 0;
2639         prot->aad_size = TLS_HEADER_SIZE;
2640         prot->tail_size = 1;
2641     } else {
2642         prot->aad_size = TLS_AAD_SPACE_SIZE;
2643         prot->tail_size = 0;
2644     }
2645 
2646     /* Sanity-check the sizes for stack allocations. */
2647     if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2648         rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
2649         prot->aad_size > TLS_MAX_AAD_SIZE) {
2650         rc = -EINVAL;
2651         goto free_priv;
2652     }
2653 
2654     prot->version = crypto_info->version;
2655     prot->cipher_type = crypto_info->cipher_type;
2656     prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2657     prot->tag_size = tag_size;
2658     prot->overhead_size = prot->prepend_size +
2659                   prot->tag_size + prot->tail_size;
2660     prot->iv_size = iv_size;
2661     prot->salt_size = salt_size;
2662     cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2663     if (!cctx->iv) {
2664         rc = -ENOMEM;
2665         goto free_priv;
2666     }
2667     /* Note: 128 & 256 bit salt are the same size */
2668     prot->rec_seq_size = rec_seq_size;
2669     memcpy(cctx->iv, salt, salt_size);
2670     memcpy(cctx->iv + salt_size, iv, iv_size);
2671     cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2672     if (!cctx->rec_seq) {
2673         rc = -ENOMEM;
2674         goto free_iv;
2675     }
2676 
2677     if (!*aead) {
2678         *aead = crypto_alloc_aead(cipher_name, 0, 0);
2679         if (IS_ERR(*aead)) {
2680             rc = PTR_ERR(*aead);
2681             *aead = NULL;
2682             goto free_rec_seq;
2683         }
2684     }
2685 
2686     ctx->push_pending_record = tls_sw_push_pending_record;
2687 
2688     rc = crypto_aead_setkey(*aead, key, keysize);
2689 
2690     if (rc)
2691         goto free_aead;
2692 
2693     rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2694     if (rc)
2695         goto free_aead;
2696 
2697     if (sw_ctx_rx) {
2698         tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2699 
2700         tls_update_rx_zc_capable(ctx);
2701         sw_ctx_rx->async_capable =
2702             crypto_info->version != TLS_1_3_VERSION &&
2703             !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2704 
2705         rc = tls_strp_init(&sw_ctx_rx->strp, sk);
2706         if (rc)
2707             goto free_aead;
2708     }
2709 
2710     goto out;
2711 
2712 free_aead:
2713     crypto_free_aead(*aead);
2714     *aead = NULL;
2715 free_rec_seq:
2716     kfree(cctx->rec_seq);
2717     cctx->rec_seq = NULL;
2718 free_iv:
2719     kfree(cctx->iv);
2720     cctx->iv = NULL;
2721 free_priv:
2722     if (tx) {
2723         kfree(ctx->priv_ctx_tx);
2724         ctx->priv_ctx_tx = NULL;
2725     } else {
2726         kfree(ctx->priv_ctx_rx);
2727         ctx->priv_ctx_rx = NULL;
2728     }
2729 out:
2730     return rc;
2731 }
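
/*
 * Editor's usage sketch (hypothetical userspace code): the setsockopt()
 * sequence that lands in tls_set_sw_offload() above. The key material is
 * a placeholder; real applications take it from their TLS handshake, and
 * TX and RX use distinct keys and sequence numbers.
 */
#if 0 /* userspace */
#include <linux/tls.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <string.h>

unsigned char key[TLS_CIPHER_AES_GCM_128_KEY_SIZE];        /* from handshake */
unsigned char iv[TLS_CIPHER_AES_GCM_128_IV_SIZE];
unsigned char salt[TLS_CIPHER_AES_GCM_128_SALT_SIZE];
unsigned char rec_seq[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE];

struct tls12_crypto_info_aes_gcm_128 ci = {
    .info.version = TLS_1_2_VERSION,
    .info.cipher_type = TLS_CIPHER_AES_GCM_128,
};

memcpy(ci.key, key, sizeof(ci.key));
memcpy(ci.iv, iv, sizeof(ci.iv));
memcpy(ci.salt, salt, sizeof(ci.salt));
memcpy(ci.rec_seq, rec_seq, sizeof(ci.rec_seq));

setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));  /* reaches tx == 1 */
setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));  /* reaches tx == 0 */
#endif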