0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 #include <linux/bug.h>
0039 #include <linux/sched/signal.h>
0040 #include <linux/module.h>
0041 #include <linux/splice.h>
0042 #include <crypto/aead.h>
0043
0044 #include <net/strparser.h>
0045 #include <net/tls.h>
0046
0047 #include "tls.h"
0048
/*
 * Arguments describing how a single received record is to be decrypted.
 *
 * The first three fields are wrapped in struct_group(inargs, ...), so they
 * can also be addressed together through the named member "inargs".
 */
struct tls_decrypt_arg {
	struct_group(inargs,
	/* Presumably "zero-copy"; when set, tls_padding_length() takes the
	 * TLS 1.3 content type from @tail instead of reading the skb first.
	 */
	bool zc;
	/* Request an async decrypt; tls_do_decryption() clears it unless the
	 * request was left in flight.
	 */
	bool async;
	/* Trailing plaintext byte, consulted only when @zc is set. */
	u8 tail;
	);

	/* NOTE(review): not used in the code visible here; presumably the
	 * skb holding the decrypted record — confirm against its users.
	 */
	struct sk_buff *skb;
};
0058
/*
 * NOTE(review): not referenced in the code visible here. Presumably the
 * per-request scratch area (IV, AAD, trailing byte and scatterlist) that
 * travels with one RX decrypt request — confirm against its users.
 */
struct tls_decrypt_ctx {
	u8 iv[MAX_IV_SIZE];		/* IV/nonce buffer */
	u8 aad[TLS_MAX_AAD_SIZE];	/* additional authenticated data */
	u8 tail;
	struct scatterlist sg[];	/* flexible array, sized at allocation */
};
0065
0066 noinline void tls_err_abort(struct sock *sk, int err)
0067 {
0068 WARN_ON_ONCE(err >= 0);
0069
0070 sk->sk_err = -err;
0071 sk_error_report(sk);
0072 }
0073
/*
 * __skb_nsg() - count the scatterlist entries needed to map @len bytes of
 * @skb starting at @offset.
 * @skb:             buffer to walk: linear head, then page frags, then the
 *                   frag_list skbs
 * @offset:          byte offset into @skb where the range starts
 * @len:             number of bytes in the range
 * @recursion_level: current frag_list nesting depth (0 for the outer skb)
 *
 * The touched part of the linear head and each touched page frag count as
 * one entry each; frag_list skbs are counted by recursing into them.
 *
 * Return: the number of entries, or -EMSGSIZE once frag_list nesting
 * reaches 24 levels (propagated up from nested calls).
 */
static int __skb_nsg(struct sk_buff *skb, int offset, int len,
		     unsigned int recursion_level)
{
	int start = skb_headlen(skb);
	int i, chunk = start - offset;
	struct sk_buff *frag_iter;
	int elt = 0;

	/* Bound recursion through nested frag_lists. */
	if (unlikely(recursion_level >= 24))
		return -EMSGSIZE;

	/* Part of the range lies in the linear head. */
	if (chunk > 0) {
		if (chunk > len)
			chunk = len;
		elt++;
		len -= chunk;
		if (len == 0)
			return elt;
		offset += chunk;
	}

	/* Page frags: [start, end) is the byte span covered by frag i. */
	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		int end;

		WARN_ON(start > offset + len);

		end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
		chunk = end - offset;
		if (chunk > 0) {
			if (chunk > len)
				chunk = len;
			elt++;
			len -= chunk;
			if (len == 0)
				return elt;
			offset += chunk;
		}
		start = end;
	}

	/* The rest must come from the frag_list skbs; offsets passed down
	 * are relative to each child skb.
	 */
	if (unlikely(skb_has_frag_list(skb))) {
		skb_walk_frags(skb, frag_iter) {
			int end, ret;

			WARN_ON(start > offset + len);

			end = start + frag_iter->len;
			chunk = end - offset;
			if (chunk > 0) {
				if (chunk > len)
					chunk = len;
				ret = __skb_nsg(frag_iter, offset - start, chunk,
						recursion_level + 1);
				if (unlikely(ret < 0))
					return ret;
				elt += ret;
				len -= chunk;
				if (len == 0)
					return elt;
				offset += chunk;
			}
			start = end;
		}
	}
	/* The skb must have been long enough to cover the whole range. */
	BUG_ON(len);
	return elt;
}
0141
0142
0143
0144
/*
 * skb_nsg() - number of scatterlist entries needed to map @len bytes of
 * @skb starting at @offset.
 *
 * Starts the recursive count at the top level of @skb.
 *
 * Return: the entry count, or a negative errno from __skb_nsg() (-EMSGSIZE
 * when the frag_list nesting is too deep).
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	const unsigned int top_level = 0;

	return __skb_nsg(skb, offset, len, top_level);
}
0149
0150 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
0151 struct tls_decrypt_arg *darg)
0152 {
0153 struct strp_msg *rxm = strp_msg(skb);
0154 struct tls_msg *tlm = tls_msg(skb);
0155 int sub = 0;
0156
0157
0158 if (prot->version == TLS_1_3_VERSION) {
0159 int offset = rxm->full_len - TLS_TAG_SIZE - 1;
0160 char content_type = darg->zc ? darg->tail : 0;
0161 int err;
0162
0163 while (content_type == 0) {
0164 if (offset < prot->prepend_size)
0165 return -EBADMSG;
0166 err = skb_copy_bits(skb, rxm->offset + offset,
0167 &content_type, 1);
0168 if (err)
0169 return err;
0170 if (content_type)
0171 break;
0172 sub++;
0173 offset--;
0174 }
0175 tlm->control = content_type;
0176 }
0177 return sub;
0178 }
0179
0180 static void tls_decrypt_done(struct crypto_async_request *req, int err)
0181 {
0182 struct aead_request *aead_req = (struct aead_request *)req;
0183 struct scatterlist *sgout = aead_req->dst;
0184 struct scatterlist *sgin = aead_req->src;
0185 struct tls_sw_context_rx *ctx;
0186 struct tls_context *tls_ctx;
0187 struct scatterlist *sg;
0188 unsigned int pages;
0189 struct sock *sk;
0190
0191 sk = (struct sock *)req->data;
0192 tls_ctx = tls_get_ctx(sk);
0193 ctx = tls_sw_ctx_rx(tls_ctx);
0194
0195
0196 if (err) {
0197 if (err == -EBADMSG)
0198 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
0199 ctx->async_wait.err = err;
0200 tls_err_abort(sk, err);
0201 }
0202
0203
0204 if (sgout != sgin) {
0205
0206 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
0207 if (!sg)
0208 break;
0209 put_page(sg_page(sg));
0210 }
0211 }
0212
0213 kfree(aead_req);
0214
0215 spin_lock_bh(&ctx->decrypt_compl_lock);
0216 if (!atomic_dec_return(&ctx->decrypt_pending))
0217 complete(&ctx->async_wait.completion);
0218 spin_unlock_bh(&ctx->decrypt_compl_lock);
0219 }
0220
0221 static int tls_do_decryption(struct sock *sk,
0222 struct scatterlist *sgin,
0223 struct scatterlist *sgout,
0224 char *iv_recv,
0225 size_t data_len,
0226 struct aead_request *aead_req,
0227 struct tls_decrypt_arg *darg)
0228 {
0229 struct tls_context *tls_ctx = tls_get_ctx(sk);
0230 struct tls_prot_info *prot = &tls_ctx->prot_info;
0231 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
0232 int ret;
0233
0234 aead_request_set_tfm(aead_req, ctx->aead_recv);
0235 aead_request_set_ad(aead_req, prot->aad_size);
0236 aead_request_set_crypt(aead_req, sgin, sgout,
0237 data_len + prot->tag_size,
0238 (u8 *)iv_recv);
0239
0240 if (darg->async) {
0241 aead_request_set_callback(aead_req,
0242 CRYPTO_TFM_REQ_MAY_BACKLOG,
0243 tls_decrypt_done, sk);
0244 atomic_inc(&ctx->decrypt_pending);
0245 } else {
0246 aead_request_set_callback(aead_req,
0247 CRYPTO_TFM_REQ_MAY_BACKLOG,
0248 crypto_req_done, &ctx->async_wait);
0249 }
0250
0251 ret = crypto_aead_decrypt(aead_req);
0252 if (ret == -EINPROGRESS) {
0253 if (darg->async)
0254 return 0;
0255
0256 ret = crypto_wait_req(ret, &ctx->async_wait);
0257 }
0258 darg->async = false;
0259
0260 return ret;
0261 }
0262
0263 static void tls_trim_both_msgs(struct sock *sk, int target_size)
0264 {
0265 struct tls_context *tls_ctx = tls_get_ctx(sk);
0266 struct tls_prot_info *prot = &tls_ctx->prot_info;
0267 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0268 struct tls_rec *rec = ctx->open_rec;
0269
0270 sk_msg_trim(sk, &rec->msg_plaintext, target_size);
0271 if (target_size > 0)
0272 target_size += prot->overhead_size;
0273 sk_msg_trim(sk, &rec->msg_encrypted, target_size);
0274 }
0275
0276 static int tls_alloc_encrypted_msg(struct sock *sk, int len)
0277 {
0278 struct tls_context *tls_ctx = tls_get_ctx(sk);
0279 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0280 struct tls_rec *rec = ctx->open_rec;
0281 struct sk_msg *msg_en = &rec->msg_encrypted;
0282
0283 return sk_msg_alloc(sk, msg_en, len, 0);
0284 }
0285
0286 static int tls_clone_plaintext_msg(struct sock *sk, int required)
0287 {
0288 struct tls_context *tls_ctx = tls_get_ctx(sk);
0289 struct tls_prot_info *prot = &tls_ctx->prot_info;
0290 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0291 struct tls_rec *rec = ctx->open_rec;
0292 struct sk_msg *msg_pl = &rec->msg_plaintext;
0293 struct sk_msg *msg_en = &rec->msg_encrypted;
0294 int skip, len;
0295
0296
0297
0298
0299
0300 len = required - msg_pl->sg.size;
0301
0302
0303
0304
0305 skip = prot->prepend_size + msg_pl->sg.size;
0306
0307 return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
0308 }
0309
0310 static struct tls_rec *tls_get_rec(struct sock *sk)
0311 {
0312 struct tls_context *tls_ctx = tls_get_ctx(sk);
0313 struct tls_prot_info *prot = &tls_ctx->prot_info;
0314 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0315 struct sk_msg *msg_pl, *msg_en;
0316 struct tls_rec *rec;
0317 int mem_size;
0318
0319 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
0320
0321 rec = kzalloc(mem_size, sk->sk_allocation);
0322 if (!rec)
0323 return NULL;
0324
0325 msg_pl = &rec->msg_plaintext;
0326 msg_en = &rec->msg_encrypted;
0327
0328 sk_msg_init(msg_pl);
0329 sk_msg_init(msg_en);
0330
0331 sg_init_table(rec->sg_aead_in, 2);
0332 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
0333 sg_unmark_end(&rec->sg_aead_in[1]);
0334
0335 sg_init_table(rec->sg_aead_out, 2);
0336 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
0337 sg_unmark_end(&rec->sg_aead_out[1]);
0338
0339 return rec;
0340 }
0341
0342 static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
0343 {
0344 sk_msg_free(sk, &rec->msg_encrypted);
0345 sk_msg_free(sk, &rec->msg_plaintext);
0346 kfree(rec);
0347 }
0348
0349 static void tls_free_open_rec(struct sock *sk)
0350 {
0351 struct tls_context *tls_ctx = tls_get_ctx(sk);
0352 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0353 struct tls_rec *rec = ctx->open_rec;
0354
0355 if (rec) {
0356 tls_free_rec(sk, rec);
0357 ctx->open_rec = NULL;
0358 }
0359 }
0360
/*
 * tls_tx_records() - hand encrypted records on ctx->tx_list to TCP.
 * @sk:    TLS socket
 * @flags: send flags for the push helpers, or -1 to use each record's own
 *         rec->tx_flags
 *
 * Records are sent in list order:
 *  - a partially sent head record is finished first;
 *  - then records are sent while they are marked tx_ready;
 *  - the first record that is not yet ready stops the walk.
 * Each fully sent record is unlinked, and its plaintext sk_msg and the
 * record itself are freed.
 *
 * Return: 0, or the error from tls_push_partial_record()/tls_push_sg().
 * Any error other than -EAGAIN also aborts the socket with -EBADMSG.
 */
int tls_tx_records(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	struct sk_msg *msg_en;
	int tx_flags, rc = 0;

	if (tls_is_partially_sent_record(tls_ctx)) {
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);

		if (flags == -1)
			tx_flags = rec->tx_flags;
		else
			tx_flags = flags;

		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
		if (rc)
			goto tx_err;

		/* Full record has been transmitted.
		 * Remove the head of tx_list
		 */
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	/* Tx all ready records, stopping at the first one not yet encrypted */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		if (READ_ONCE(rec->tx_ready)) {
			if (flags == -1)
				tx_flags = rec->tx_flags;
			else
				tx_flags = flags;

			msg_en = &rec->msg_encrypted;
			rc = tls_push_sg(sk, tls_ctx,
					 &msg_en->sg.data[msg_en->sg.curr],
					 0, tx_flags);
			if (rc)
				goto tx_err;

			list_del(&rec->list);
			sk_msg_free(sk, &rec->msg_plaintext);
			kfree(rec);
		} else {
			break;
		}
	}

tx_err:
	/* -EAGAIN just means "try again later"; anything else is fatal. */
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, -EBADMSG);

	return rc;
}
0419
0420 static void tls_encrypt_done(struct crypto_async_request *req, int err)
0421 {
0422 struct aead_request *aead_req = (struct aead_request *)req;
0423 struct sock *sk = req->data;
0424 struct tls_context *tls_ctx = tls_get_ctx(sk);
0425 struct tls_prot_info *prot = &tls_ctx->prot_info;
0426 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0427 struct scatterlist *sge;
0428 struct sk_msg *msg_en;
0429 struct tls_rec *rec;
0430 bool ready = false;
0431 int pending;
0432
0433 rec = container_of(aead_req, struct tls_rec, aead_req);
0434 msg_en = &rec->msg_encrypted;
0435
0436 sge = sk_msg_elem(msg_en, msg_en->sg.curr);
0437 sge->offset -= prot->prepend_size;
0438 sge->length += prot->prepend_size;
0439
0440
0441 if (err || sk->sk_err) {
0442 rec = NULL;
0443
0444
0445 if (sk->sk_err) {
0446 ctx->async_wait.err = -sk->sk_err;
0447 } else {
0448 ctx->async_wait.err = err;
0449 tls_err_abort(sk, err);
0450 }
0451 }
0452
0453 if (rec) {
0454 struct tls_rec *first_rec;
0455
0456
0457 smp_store_mb(rec->tx_ready, true);
0458
0459
0460 first_rec = list_first_entry(&ctx->tx_list,
0461 struct tls_rec, list);
0462 if (rec == first_rec)
0463 ready = true;
0464 }
0465
0466 spin_lock_bh(&ctx->encrypt_compl_lock);
0467 pending = atomic_dec_return(&ctx->encrypt_pending);
0468
0469 if (!pending && ctx->async_notify)
0470 complete(&ctx->async_wait.completion);
0471 spin_unlock_bh(&ctx->encrypt_compl_lock);
0472
0473 if (!ready)
0474 return;
0475
0476
0477 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
0478 schedule_delayed_work(&ctx->tx_work.work, 1);
0479 }
0480
0481 static int tls_do_encryption(struct sock *sk,
0482 struct tls_context *tls_ctx,
0483 struct tls_sw_context_tx *ctx,
0484 struct aead_request *aead_req,
0485 size_t data_len, u32 start)
0486 {
0487 struct tls_prot_info *prot = &tls_ctx->prot_info;
0488 struct tls_rec *rec = ctx->open_rec;
0489 struct sk_msg *msg_en = &rec->msg_encrypted;
0490 struct scatterlist *sge = sk_msg_elem(msg_en, start);
0491 int rc, iv_offset = 0;
0492
0493
0494 switch (prot->cipher_type) {
0495 case TLS_CIPHER_AES_CCM_128:
0496 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
0497 iv_offset = 1;
0498 break;
0499 case TLS_CIPHER_SM4_CCM:
0500 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
0501 iv_offset = 1;
0502 break;
0503 }
0504
0505 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
0506 prot->iv_size + prot->salt_size);
0507
0508 tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
0509 tls_ctx->tx.rec_seq);
0510
0511 sge->offset += prot->prepend_size;
0512 sge->length -= prot->prepend_size;
0513
0514 msg_en->sg.curr = start;
0515
0516 aead_request_set_tfm(aead_req, ctx->aead_send);
0517 aead_request_set_ad(aead_req, prot->aad_size);
0518 aead_request_set_crypt(aead_req, rec->sg_aead_in,
0519 rec->sg_aead_out,
0520 data_len, rec->iv_data);
0521
0522 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
0523 tls_encrypt_done, sk);
0524
0525
0526 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
0527 atomic_inc(&ctx->encrypt_pending);
0528
0529 rc = crypto_aead_encrypt(aead_req);
0530 if (!rc || rc != -EINPROGRESS) {
0531 atomic_dec(&ctx->encrypt_pending);
0532 sge->offset -= prot->prepend_size;
0533 sge->length += prot->prepend_size;
0534 }
0535
0536 if (!rc) {
0537 WRITE_ONCE(rec->tx_ready, true);
0538 } else if (rc != -EINPROGRESS) {
0539 list_del(&rec->list);
0540 return rc;
0541 }
0542
0543
0544 ctx->open_rec = NULL;
0545 tls_advance_record_sn(sk, prot, &tls_ctx->tx);
0546 return rc;
0547 }
0548
/*
 * tls_split_open_record() - move the tail of an open record into a new one.
 * @sk:               socket
 * @from:             record being split (not read by this function)
 * @to:               on success, set to the newly allocated record
 * @msg_opl:          plaintext of the record being split
 * @msg_oen:          encrypted msg of that record (not read by this function)
 * @split_point:      not read by this function; the boundary is taken from
 *                    msg_opl->apply_bytes
 * @tx_overhead_size: per-record overhead added to the new record's
 *                    encrypted allocation
 * @orig_end:         receives msg_opl->sg.end from before the split, so
 *                    tls_merge_open_record() can undo it
 *
 * Keeps the first apply_bytes bytes in @msg_opl and copies the remaining
 * scatterlist entries into the new record's plaintext. If the boundary
 * falls inside an element, that element is cut in two and the second half
 * takes an extra page reference.
 *
 * Return: 0 on success, -ENOMEM if no record could be allocated, or the
 * negative result of sk_msg_alloc() for the new encrypted buffer.
 */
static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
				 struct tls_rec **to, struct sk_msg *msg_opl,
				 struct sk_msg *msg_oen, u32 split_point,
				 u32 tx_overhead_size, u32 *orig_end)
{
	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
	struct scatterlist *sge, *osge, *nsge;
	u32 orig_size = msg_opl->sg.size;
	struct scatterlist tmp = { };
	struct sk_msg *msg_npl;
	struct tls_rec *new;
	int ret;

	new = tls_get_rec(sk);
	if (!new)
		return -ENOMEM;
	/* Sized for the whole original plaintext plus overhead. */
	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
			   tx_overhead_size, 0);
	if (ret < 0) {
		tls_free_rec(sk, new);
		return ret;
	}

	*orig_end = msg_opl->sg.end;
	i = msg_opl->sg.start;
	sge = sk_msg_elem(msg_opl, i);
	/* Walk elements until apply_bytes are consumed. */
	while (apply && sge->length) {
		if (sge->length > apply) {
			/* Boundary inside this element: the part beyond it
			 * goes to tmp, with its own page reference.
			 */
			u32 len = sge->length - apply;

			get_page(sg_page(sge));
			sg_set_page(&tmp, sg_page(sge), len,
				    sge->offset + apply);
			sge->length = apply;
			bytes += apply;
			apply = 0;
		} else {
			apply -= sge->length;
			bytes += sge->length;
		}

		sk_msg_iter_var_next(i);
		if (i == msg_opl->sg.end)
			break;
		sge = sk_msg_elem(msg_opl, i);
	}

	/* The original record now ends at element i. */
	msg_opl->sg.end = i;
	msg_opl->sg.curr = i;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = 0;
	msg_opl->sg.size = bytes;

	msg_npl = &new->msg_plaintext;
	msg_npl->apply_bytes = apply;
	msg_npl->sg.size = orig_size - bytes;

	j = msg_npl->sg.start;
	nsge = sk_msg_elem(msg_npl, j);
	/* The second half of a cut element comes first in the new record. */
	if (tmp.length) {
		memcpy(nsge, &tmp, sizeof(*nsge));
		sk_msg_iter_var_next(j);
		nsge = sk_msg_elem(msg_npl, j);
	}

	/* Copy the remaining elements; they are not cleared in msg_opl. */
	osge = sk_msg_elem(msg_opl, i);
	while (osge->length) {
		memcpy(nsge, osge, sizeof(*nsge));
		sg_unmark_end(nsge);
		sk_msg_iter_var_next(i);
		sk_msg_iter_var_next(j);
		if (i == *orig_end)
			break;
		osge = sk_msg_elem(msg_opl, i);
		nsge = sk_msg_elem(msg_npl, j);
	}

	msg_npl->sg.end = j;
	msg_npl->sg.curr = j;
	msg_npl->sg.copybreak = 0;

	*to = new;
	return 0;
}
0633
/*
 * tls_merge_open_record() - undo tls_split_open_record().
 * @sk:       socket
 * @to:       original record; gets its plaintext back
 * @from:     record created by the split; freed here (only the struct,
 *            its plaintext entries are shared with @to)
 * @orig_end: msg_plaintext.sg.end of @to from before the split
 *
 * The split copied, rather than cleared, the moved entries, so restoring
 * sg.end reclaims them. If @to's last element and @from's first element
 * are contiguous on the same page (an element cut by the split), they are
 * joined and the extra page reference is dropped. apply_bytes is set to
 * the merged size, and @to takes over @from's encrypted buffer after
 * freeing its own.
 */
static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
				  struct tls_rec *from, u32 orig_end)
{
	struct sk_msg *msg_npl = &from->msg_plaintext;
	struct sk_msg *msg_opl = &to->msg_plaintext;
	struct scatterlist *osge, *nsge;
	u32 i, j;

	i = msg_opl->sg.end;
	sk_msg_iter_var_prev(i);
	j = msg_npl->sg.start;

	osge = sk_msg_elem(msg_opl, i);
	nsge = sk_msg_elem(msg_npl, j);

	/* Re-join an element that the split cut in two. */
	if (sg_page(osge) == sg_page(nsge) &&
	    osge->offset + osge->length == nsge->offset) {
		osge->length += nsge->length;
		put_page(sg_page(nsge));
	}

	msg_opl->sg.end = orig_end;
	msg_opl->sg.curr = orig_end;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
	msg_opl->sg.size += msg_npl->sg.size;

	sk_msg_free(sk, &to->msg_encrypted);
	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);

	kfree(from);
}
0666
/*
 * tls_push_record() - close the open TX record, encrypt it and send.
 * @sk:          TLS socket
 * @flags:       send flags stored in the record and passed to tls_tx_records()
 * @record_type: TLS content type written to the header and AAD
 *
 * The open record is split first in either of these cases:
 *  - a BPF verdict limited it (apply_bytes smaller than the plaintext);
 *  - the encrypted buffer cannot hold the plaintext plus overhead.
 * The remainder becomes the next open record; if nothing would stay in the
 * current one, the split is undone. The plaintext and encrypted sk_msg
 * rings are then wired into the AEAD scatterlists, handling ring
 * wrap-around and, for TLS 1.3, the trailing content-type byte. Finally
 * tls_do_encryption() is called.
 *
 * Return: -EINPROGRESS when encryption completes asynchronously. Otherwise,
 * on encryption success, the result of tls_tx_records(). On split failure,
 * the error from tls_split_open_record(). On encryption failure, its error,
 * after aborting the socket with -EBADMSG and undoing any split.
 * ctx->async_capable is set whenever tls_do_encryption() returns < 0.
 *
 * NOTE(review): on -EINPROGRESS with a split, tmp is not installed as
 * ctx->open_rec here (tls_do_encryption() has already cleared it); confirm
 * the remainder of the record is not dropped in that case.
 */
static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
	u32 i, split_point, orig_end;
	struct sk_msg *msg_pl, *msg_en;
	struct aead_request *req;
	bool split;
	int rc;

	if (!rec)
		return 0;

	msg_pl = &rec->msg_plaintext;
	msg_en = &rec->msg_encrypted;

	split_point = msg_pl->apply_bytes;
	split = split_point && split_point < msg_pl->sg.size;
	/* Also split when the encrypted buffer is too small. */
	if (unlikely((!split &&
		      msg_pl->sg.size +
		      prot->overhead_size > msg_en->sg.size) ||
		     (split &&
		      split_point +
		      prot->overhead_size > msg_en->sg.size))) {
		split = true;
		split_point = msg_en->sg.size;
	}
	if (split) {
		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
					   split_point, prot->overhead_size,
					   &orig_end);
		if (rc < 0)
			return rc;
		/* This can happen if above tls_split_open_record allocates
		 * a single large encryption buffer instead of two smaller
		 * ones. In this case adjust pointers and continue without
		 * split.
		 */
		if (!msg_pl->sg.size) {
			tls_merge_open_record(sk, rec, tmp, orig_end);
			msg_pl = &rec->msg_plaintext;
			msg_en = &rec->msg_encrypted;
			split = false;
		}
		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
			    prot->overhead_size);
	}

	rec->tx_flags = flags;
	req = &rec->aead_req;

	i = msg_pl->sg.end;
	sk_msg_iter_var_prev(i);

	rec->content_type = record_type;
	if (prot->version == TLS_1_3_VERSION) {
		/* Add content type to end of message.  No padding added */
		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
		sg_mark_end(&rec->sg_content_type);
		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
			 &rec->sg_content_type);
	} else {
		sg_mark_end(sk_msg_elem(msg_pl, i));
	}

	/* Ring wrapped: chain the tail of data[] back to its start.
	 * NOTE(review): the chain length is derived from MAX_SKB_FRAGS;
	 * confirm it places the chain entry past the last usable ring slot
	 * of the sk_msg so no live element is overwritten.
	 */
	if (msg_pl->sg.end < msg_pl->sg.start) {
		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
			 msg_pl->sg.data);
	}

	/* AEAD input: AAD (slot 0) followed by the plaintext ring. */
	i = msg_pl->sg.start;
	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);

	i = msg_en->sg.end;
	sk_msg_iter_var_prev(i);
	sg_mark_end(sk_msg_elem(msg_en, i));

	/* AEAD output: AAD (slot 0) followed by the encrypted ring. */
	i = msg_en->sg.start;
	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);

	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
		     tls_ctx->tx.rec_seq, record_type, prot);

	/* The TLS header is written at the start of the encrypted buffer. */
	tls_fill_prepend(tls_ctx,
			 page_address(sg_page(&msg_en->sg.data[i])) +
			 msg_en->sg.data[i].offset,
			 msg_pl->sg.size + prot->tail_size,
			 record_type);

	tls_ctx->pending_open_record_frags = false;

	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
			       msg_pl->sg.size + prot->tail_size, i);
	if (rc < 0) {
		if (rc != -EINPROGRESS) {
			tls_err_abort(sk, -EBADMSG);
			if (split) {
				tls_ctx->pending_open_record_frags = true;
				tls_merge_open_record(sk, rec, tmp, orig_end);
			}
		}
		ctx->async_capable = 1;
		return rc;
	} else if (split) {
		/* The remainder becomes the new open record. */
		msg_pl = &tmp->msg_plaintext;
		msg_en = &tmp->msg_encrypted;
		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
		tls_ctx->pending_open_record_frags = true;
		ctx->open_rec = tmp;
	}

	return tls_tx_records(sk, flags);
}
0784
/*
 * bpf_exec_tx_verdict() - apply the sk_msg BPF verdict, if any, to the open
 * record and act on it.
 * @msg:         plaintext of the open record
 * @sk:          TLS socket; the socket lock is held on entry and is dropped
 *               around a redirect
 * @full_record: the record cannot grow further (disables cork deferral)
 * @record_type: TLS content type used if the record is pushed
 * @copied:      caller's running byte count; reduced by bytes freed or
 *               dropped here
 * @flags:       send flags; MSG_SENDPAGE_NOPOLICY skips the verdict
 *
 * Without a psock, or with MSG_SENDPAGE_NOPOLICY, the record is pushed
 * directly. Otherwise the verdict (evaluated only while psock->eval is
 * __SK_NONE) decides how the data is handled:
 *  - pass: encrypt and send;
 *  - redirect: hand to another socket via tcp_bpf_sendmsg_redir();
 *  - drop: free it.
 * This repeats while an open record remains.
 *
 * Return: 0 on success; -ENOSPC when corking wants more data; -EACCES for
 * dropped data; or an error from the push or redirect path.
 */
static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
			       bool full_record, u8 record_type,
			       ssize_t *copied, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct sk_msg msg_redir = { };
	struct sk_psock *psock;
	struct sock *sk_redir;
	struct tls_rec *rec;
	bool enospc, policy;
	int err = 0, send;
	u32 delta = 0;

	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
	psock = sk_psock_get(sk);
	if (!psock || !policy) {
		err = tls_push_record(sk, flags, record_type);
		/* On a fatal encryption error drop the open record and
		 * stop counting its bytes as sent.
		 */
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
		}
		if (psock)
			sk_psock_put(sk, psock);
		return err;
	}
more_data:
	enospc = sk_msg_full(msg);
	if (psock->eval == __SK_NONE) {
		/* delta: bytes the BPF program removed from the message. */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}
	/* Corked: wait for more data unless the record cannot grow. */
	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
	    !enospc && !full_record) {
		err = -ENOSPC;
		goto out_err;
	}
	msg->cork_bytes = 0;
	send = msg->sg.size;
	if (msg->apply_bytes && msg->apply_bytes < send)
		send = msg->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		err = tls_push_record(sk, flags, record_type);
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
			goto out_err;
		}
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		memcpy(&msg_redir, msg, sizeof(*msg));
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		sk_msg_return_zero(sk, msg, send);
		msg->sg.size -= send;
		/* The target socket is sent to without holding our lock. */
		release_sock(sk);
		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
		lock_sock(sk);
		if (err < 0) {
			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
			msg->sg.size = 0;
		}
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, send);
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		*copied -= (send + delta);
		err = -EACCES;
	}

	if (likely(!err)) {
		/* Re-evaluate the verdict once the apply_bytes window (or
		 * the open record) is used up.
		 */
		bool reset_eval = !ctx->open_rec;

		rec = ctx->open_rec;
		if (rec) {
			msg = &rec->msg_plaintext;
			if (!msg->apply_bytes)
				reset_eval = true;
		}
		if (reset_eval) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (rec)
			goto more_data;
	}
 out_err:
	sk_psock_put(sk, psock);
	return err;
}
0894
0895 static int tls_sw_push_pending_record(struct sock *sk, int flags)
0896 {
0897 struct tls_context *tls_ctx = tls_get_ctx(sk);
0898 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
0899 struct tls_rec *rec = ctx->open_rec;
0900 struct sk_msg *msg_pl;
0901 size_t copied;
0902
0903 if (!rec)
0904 return 0;
0905
0906 msg_pl = &rec->msg_plaintext;
0907 copied = msg_pl->sg.size;
0908 if (!copied)
0909 return 0;
0910
0911 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
0912 &copied, flags);
0913 }
0914
/*
 * tls_sw_sendmsg() - sendmsg() for a TLS socket using software TX crypto.
 * @sk:   TLS socket
 * @msg:  user message. Only MSG_MORE, MSG_DONTWAIT, MSG_NOSIGNAL and
 *        MSG_CMSG_COMPAT are accepted. Control messages go to
 *        tls_process_cmsg(), which may change the record type.
 * @size: not read here; the amount to send comes from msg->msg_iter
 *
 * Data is appended to the open record, up to TLS_MAX_PAYLOAD_SIZE bytes
 * per record. The record goes to bpf_exec_tx_verdict() when it is full or
 * at the end of the write (no MSG_MORE).
 * - Zero-copy: non-kvec iterators map user pages directly, provided
 *   ctx->async_capable is clear.
 * - Copy: otherwise data is copied into pages cloned from the record's
 *   encrypted buffer.
 * If zero-copy records were encrypted asynchronously, the call waits for
 * all pending encryptions before returning (presumably because those
 * records still reference the user pages).
 *
 * Holds tls_ctx->tx_lock and the socket lock for the whole call.
 *
 * Return: number of bytes queued if any, otherwise a negative errno.
 */
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	bool async_capable = ctx->async_capable;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy;
	ssize_t copied = 0;
	struct sk_msg *msg_pl, *msg_en;
	struct tls_rec *rec;
	int required_size;
	int num_async = 0;
	bool full_record;
	int record_room;
	int num_zc = 0;
	int orig_size;
	int ret = 0;
	int pending;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
			       MSG_CMSG_COMPAT))
		return -EOPNOTSUPP;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		ret = tls_process_cmsg(sk, msg, &record_type);
		if (ret) {
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
		}
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		if (ctx->open_rec)
			rec = ctx->open_rec;
		else
			rec = ctx->open_rec = tls_get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		msg_pl = &rec->msg_plaintext;
		msg_en = &rec->msg_encrypted;

		orig_size = msg_pl->sg.size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		required_size = msg_pl->sg.size + try_to_copy +
				prot->overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = tls_alloc_encrypted_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_en->sg.size;
			full_record = true;
		}

		/* Zero-copy path: map the user pages into the record. */
		if (!is_kvec && (full_record || eor) && !async_capable) {
			u32 first = msg_pl->sg.end;

			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
							msg_pl, try_to_copy);
			if (ret)
				goto fallback_to_reg_send;

			num_zc++;
			copied += try_to_copy;

			sk_msg_sg_copy_set(msg_pl, first);
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ctx->open_rec && ret == -ENOSPC)
					goto rollback_iter;
				else if (ret != -EAGAIN)
					goto send_end;
			}
			continue;
rollback_iter:
			/* Undo the zero-copy mapping and retry by copying. */
			copied -= try_to_copy;
			sk_msg_sg_copy_clear(msg_pl, first);
			iov_iter_revert(&msg->msg_iter,
					msg_pl->sg.size - orig_size);
fallback_to_reg_send:
			sk_msg_trim(sk, msg_pl, orig_size);
		}

		required_size = msg_pl->sg.size + try_to_copy;

		ret = tls_clone_plaintext_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto send_end;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_pl->sg.size;
			full_record = true;
			sk_msg_trim(sk, msg_en,
				    msg_pl->sg.size + prot->overhead_size);
		}

		if (try_to_copy) {
			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
						       msg_pl, try_to_copy);
			if (ret < 0)
				goto trim_sgl;
		}

		/* Open records defined only if successfully copied, otherwise
		 * we would trim the sg but not reset the open record frags.
		 */
		tls_ctx->pending_open_record_frags = true;
		copied += try_to_copy;
		if (full_record || eor) {
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ret != -EAGAIN) {
					if (ret == -ENOSPC)
						ret = 0;
					goto send_end;
				}
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			if (ctx->open_rec)
				tls_trim_both_msgs(sk, orig_size);
			goto send_end;
		}

		/* Retry the allocation only if the buffer is still short. */
		if (ctx->open_rec && msg_en->sg.size < required_size)
			goto alloc_encrypted;
	}

	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		spin_lock_bh(&ctx->encrypt_compl_lock);
		ctx->async_notify = true;

		pending = atomic_read(&ctx->encrypt_pending);
		spin_unlock_bh(&ctx->encrypt_compl_lock);
		if (pending)
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		/* There can be no concurrent accesses, since we have no
		 * pending encrypt operations
		 */
		WRITE_ONCE(ctx->async_notify, false);

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
			copied = 0;
		}
	}

	/* Transmit if any encryptions have completed */
	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
		cancel_delayed_work(&ctx->tx_work.work);
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return copied > 0 ? copied : ret;
}
1138
1139 static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1140 int offset, size_t size, int flags)
1141 {
1142 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1143 struct tls_context *tls_ctx = tls_get_ctx(sk);
1144 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1145 struct tls_prot_info *prot = &tls_ctx->prot_info;
1146 unsigned char record_type = TLS_RECORD_TYPE_DATA;
1147 struct sk_msg *msg_pl;
1148 struct tls_rec *rec;
1149 int num_async = 0;
1150 ssize_t copied = 0;
1151 bool full_record;
1152 int record_room;
1153 int ret = 0;
1154 bool eor;
1155
1156 eor = !(flags & MSG_SENDPAGE_NOTLAST);
1157 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1158
1159
1160 while (size > 0) {
1161 size_t copy, required_size;
1162
1163 if (sk->sk_err) {
1164 ret = -sk->sk_err;
1165 goto sendpage_end;
1166 }
1167
1168 if (ctx->open_rec)
1169 rec = ctx->open_rec;
1170 else
1171 rec = ctx->open_rec = tls_get_rec(sk);
1172 if (!rec) {
1173 ret = -ENOMEM;
1174 goto sendpage_end;
1175 }
1176
1177 msg_pl = &rec->msg_plaintext;
1178
1179 full_record = false;
1180 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1181 copy = size;
1182 if (copy >= record_room) {
1183 copy = record_room;
1184 full_record = true;
1185 }
1186
1187 required_size = msg_pl->sg.size + copy + prot->overhead_size;
1188
1189 if (!sk_stream_memory_free(sk))
1190 goto wait_for_sndbuf;
1191 alloc_payload:
1192 ret = tls_alloc_encrypted_msg(sk, required_size);
1193 if (ret) {
1194 if (ret != -ENOSPC)
1195 goto wait_for_memory;
1196
1197
1198
1199
1200
1201 copy -= required_size - msg_pl->sg.size;
1202 full_record = true;
1203 }
1204
1205 sk_msg_page_add(msg_pl, page, copy, offset);
1206 sk_mem_charge(sk, copy);
1207
1208 offset += copy;
1209 size -= copy;
1210 copied += copy;
1211
1212 tls_ctx->pending_open_record_frags = true;
1213 if (full_record || eor || sk_msg_full(msg_pl)) {
1214 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1215 record_type, &copied, flags);
1216 if (ret) {
1217 if (ret == -EINPROGRESS)
1218 num_async++;
1219 else if (ret == -ENOMEM)
1220 goto wait_for_memory;
1221 else if (ret != -EAGAIN) {
1222 if (ret == -ENOSPC)
1223 ret = 0;
1224 goto sendpage_end;
1225 }
1226 }
1227 }
1228 continue;
1229 wait_for_sndbuf:
1230 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1231 wait_for_memory:
1232 ret = sk_stream_wait_memory(sk, &timeo);
1233 if (ret) {
1234 if (ctx->open_rec)
1235 tls_trim_both_msgs(sk, msg_pl->sg.size);
1236 goto sendpage_end;
1237 }
1238
1239 if (ctx->open_rec)
1240 goto alloc_payload;
1241 }
1242
1243 if (num_async) {
1244
1245 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1246 cancel_delayed_work(&ctx->tx_work.work);
1247 tls_tx_records(sk, flags);
1248 }
1249 }
1250 sendpage_end:
1251 ret = sk_stream_error(sk, flags, ret);
1252 return copied > 0 ? copied : ret;
1253 }
1254
1255 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1256 int offset, size_t size, int flags)
1257 {
1258 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1259 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1260 MSG_NO_SHARED_FRAGS))
1261 return -EOPNOTSUPP;
1262
1263 return tls_sw_do_sendpage(sk, page, offset, size, flags);
1264 }
1265
1266 int tls_sw_sendpage(struct sock *sk, struct page *page,
1267 int offset, size_t size, int flags)
1268 {
1269 struct tls_context *tls_ctx = tls_get_ctx(sk);
1270 int ret;
1271
1272 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1273 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1274 return -EOPNOTSUPP;
1275
1276 mutex_lock(&tls_ctx->tx_lock);
1277 lock_sock(sk);
1278 ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1279 release_sock(sk);
1280 mutex_unlock(&tls_ctx->tx_lock);
1281 return ret;
1282 }
1283
/*
 * Wait until the strparser holds a complete record, the psock has queued
 * data, or the socket hits an error, EOF, timeout or signal.
 *
 * Return: 1 once a record is ready (after tls_strp_msg_load()); 0 when the
 * psock queue is non-empty or on RCV_SHUTDOWN / SOCK_DONE; otherwise a
 * negative errno (socket error, -EAGAIN when no waiting time is left,
 * or the signal errno).
 *
 * @released is forwarded to tls_strp_msg_load(). It is forced to true once
 * this function has slept in sk_wait_event(), which (as the kernel API
 * presumably guarantees) drops the socket lock while waiting.
 */
static int
tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
		bool released)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	long timeo;

	/* Remaining wait budget; presumably 0 for non-blocking callers */
	timeo = sock_rcvtimeo(sk, nonblock);

	while (!tls_strp_msg_ready(ctx)) {
		/* psock has data: return 0 so the caller reads it via sk_msg */
		if (!sk_psock_queue_empty(psock))
			return 0;

		if (sk->sk_err)
			return sock_error(sk);

		/* TCP has bytes queued: let the strparser try to build a record */
		if (!skb_queue_empty(&sk->sk_receive_queue)) {
			tls_strp_check_rcv(&ctx->strp);
			if (tls_strp_msg_ready(ctx))
				break;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return 0;

		if (sock_flag(sk, SOCK_DONE))
			return 0;

		/* Non-blocking, or the timeout has been used up */
		if (!timeo)
			return -EAGAIN;

		released = true;
		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo,
			      tls_strp_msg_ready(ctx) ||
			      !sk_psock_queue_empty(psock),
			      &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current))
			return sock_intr_errno(timeo);
	}

	tls_strp_msg_load(&ctx->strp, released);

	return 1;
}
1336
1337 static int tls_setup_from_iter(struct iov_iter *from,
1338 int length, int *pages_used,
1339 struct scatterlist *to,
1340 int to_max_pages)
1341 {
1342 int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1343 struct page *pages[MAX_SKB_FRAGS];
1344 unsigned int size = 0;
1345 ssize_t copied, use;
1346 size_t offset;
1347
1348 while (length > 0) {
1349 i = 0;
1350 maxpages = to_max_pages - num_elem;
1351 if (maxpages == 0) {
1352 rc = -EFAULT;
1353 goto out;
1354 }
1355 copied = iov_iter_get_pages2(from, pages,
1356 length,
1357 maxpages, &offset);
1358 if (copied <= 0) {
1359 rc = -EFAULT;
1360 goto out;
1361 }
1362
1363 length -= copied;
1364 size += copied;
1365 while (copied) {
1366 use = min_t(int, copied, PAGE_SIZE - offset);
1367
1368 sg_set_page(&to[num_elem],
1369 pages[i], use, offset);
1370 sg_unmark_end(&to[num_elem]);
1371
1372
1373 offset = 0;
1374 copied -= use;
1375
1376 i++;
1377 num_elem++;
1378 }
1379 }
1380
1381 if (num_elem > *pages_used)
1382 sg_mark_end(&to[num_elem - 1]);
1383 out:
1384 if (rc)
1385 iov_iter_revert(from, size);
1386 *pages_used = num_elem;
1387
1388 return rc;
1389 }
1390
1391 static struct sk_buff *
1392 tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
1393 unsigned int full_len)
1394 {
1395 struct strp_msg *clr_rxm;
1396 struct sk_buff *clr_skb;
1397 int err;
1398
1399 clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
1400 &err, sk->sk_allocation);
1401 if (!clr_skb)
1402 return NULL;
1403
1404 skb_copy_header(clr_skb, skb);
1405 clr_skb->len = full_len;
1406 clr_skb->data_len = full_len;
1407
1408 clr_rxm = strp_msg(clr_skb);
1409 clr_rxm->offset = 0;
1410
1411 return clr_skb;
1412 }
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
/*
 * Decrypt the record currently held by the strparser (tls_strp_msg(ctx)).
 *
 * Where the clear text goes depends on @darg->zc and the output arguments:
 *  - zero-copy into the user pages behind @out_iov (sgout built from it),
 *  - zero-copy into the caller-supplied scatterlist @out_sg,
 *  - otherwise into a freshly allocated skb (darg->zc is forced to false).
 *
 * A single kmalloc() holds the aead_request (plus the transform's request
 * context), followed by a struct tls_decrypt_ctx whose trailing sg[] array
 * provides both the input (n_sgin) and output (n_sgout) scatterlists.
 *
 * On success darg->skb points at the clear-text skb, or at the strparser's
 * skb in the zero-copy case. If darg->async is set after tls_do_decryption(),
 * the function returns early without freeing the request memory or the
 * pinned pages; presumably the async completion path releases them
 * (TODO confirm).
 *
 * Return: 0 on success, negative errno on failure.
 */
static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
			  struct scatterlist *out_sg,
			  struct tls_decrypt_arg *darg)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	int n_sgin, n_sgout, aead_size, err, pages = 0;
	struct sk_buff *skb = tls_strp_msg(ctx);
	const struct strp_msg *rxm = strp_msg(skb);
	const struct tls_msg *tlm = tls_msg(skb);
	struct aead_request *aead_req;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - prot->overhead_size;
	int tail_pages = !!prot->tail_size;
	struct tls_decrypt_ctx *dctx;
	struct sk_buff *clear_skb;
	int iv_offset = 0;
	u8 *mem;

	/* Entries needed to map the ciphertext that follows the record header */
	n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
			 rxm->full_len - prot->prepend_size);
	if (n_sgin < 1)
		return n_sgin ?: -EBADMSG;

	if (darg->zc && (out_iov || out_sg)) {
		clear_skb = NULL;

		/* AAD entry + payload pages + optional entry for the tail byte */
		if (out_iov)
			n_sgout = 1 + tail_pages +
				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
		else
			n_sgout = sg_nents(out_sg);
	} else {
		darg->zc = false;

		clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
		if (!clear_skb)
			return -ENOMEM;

		/* AAD entry + one entry per page fragment of clear_skb */
		n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
	}

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	/* Allocate a single block of memory which contains
	 * aead_req || tls_decrypt_ctx.
	 * Both structs are variable length.
	 */
	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
		      sk->sk_allocation);
	if (!mem) {
		err = -ENOMEM;
		goto exit_free_skb;
	}

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
	sgin = &dctx->sg[0];
	sgout = &dctx->sg[n_sgin];

	/* For CCM based ciphers, first byte of nonce+iv is a constant */
	switch (prot->cipher_type) {
	case TLS_CIPHER_AES_CCM_128:
		dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	case TLS_CIPHER_SM4_CCM:
		dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
		iv_offset = 1;
		break;
	}

	/* Prepare IV: TLS 1.3 and ChaCha20 take the whole IV from rx.iv;
	 * otherwise the salt comes from rx.iv and the explicit nonce is read
	 * from the record, right after the TLS header.
	 */
	if (prot->version == TLS_1_3_VERSION ||
	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
		memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
		       prot->iv_size + prot->salt_size);
	} else {
		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
				    &dctx->iv[iv_offset] + prot->salt_size,
				    prot->iv_size);
		if (err < 0)
			goto exit_free;
		memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
	}
	/* Mix in the record sequence number (per the helper's name; see
	 * tls_xor_iv_with_seq() for the exact rule).
	 */
	tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);

	/* Prepare AAD: plaintext length (incl. any tail), sequence, type */
	tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
		     prot->tail_size,
		     tls_ctx->rx.rec_seq, tlm->control, prot);

	/* Prepare sgin: AAD first, then the ciphertext after the prepend */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + prot->prepend_size,
			   rxm->full_len - prot->prepend_size);
	if (err < 0)
		goto exit_free;

	if (clear_skb) {
		sg_init_table(sgout, n_sgout);
		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);

		err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
				   data_len + prot->tail_size);
		if (err < 0)
			goto exit_free;
	} else if (out_iov) {
		sg_init_table(sgout, n_sgout);
		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);

		/* Pins user pages; 'pages' counts entries filled from sgout[1] */
		err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
					  (n_sgout - 1 - tail_pages));
		if (err < 0)
			goto exit_free_pages;

		/* The trailing (TLS 1.3 content type) byte lands in dctx->tail */
		if (prot->tail_size) {
			sg_unmark_end(&sgout[pages]);
			sg_set_buf(&sgout[pages + 1], &dctx->tail,
				   prot->tail_size);
			sg_mark_end(&sgout[pages + 1]);
		}
	} else if (out_sg) {
		memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
				data_len + prot->tail_size, aead_req, darg);
	if (err)
		goto exit_free_pages;

	/* Ownership of clear_skb moves to darg->skb; don't free it below */
	darg->skb = clear_skb ?: tls_strp_msg(ctx);
	clear_skb = NULL;

	if (unlikely(darg->async)) {
		/* On failure the clear-text skb is parked on async_hold, which
		 * tls_sw_recvmsg() purges after waiting for pending decrypts.
		 */
		err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
		if (err)
			__skb_queue_tail(&ctx->async_hold, darg->skb);
		return err;
	}

	if (prot->tail_size)
		darg->tail = dctx->tail;

exit_free_pages:
	/* Release the pages from tls_setup_from_iter() (sgout[1..pages]) */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));
exit_free:
	kfree(mem);
exit_free_skb:
	consume_skb(clear_skb);
	return err;
}
1596
1597 static int
1598 tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
1599 struct msghdr *msg, struct tls_decrypt_arg *darg)
1600 {
1601 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1602 struct tls_prot_info *prot = &tls_ctx->prot_info;
1603 struct strp_msg *rxm;
1604 int pad, err;
1605
1606 err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
1607 if (err < 0) {
1608 if (err == -EBADMSG)
1609 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
1610 return err;
1611 }
1612
1613
1614
1615 if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
1616 darg->tail != TLS_RECORD_TYPE_DATA)) {
1617 darg->zc = false;
1618 if (!darg->tail)
1619 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
1620 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
1621 return tls_decrypt_sw(sk, tls_ctx, msg, darg);
1622 }
1623
1624 pad = tls_padding_length(prot, darg->skb, darg);
1625 if (pad < 0) {
1626 if (darg->skb != tls_strp_msg(ctx))
1627 consume_skb(darg->skb);
1628 return pad;
1629 }
1630
1631 rxm = strp_msg(darg->skb);
1632 rxm->full_len -= pad;
1633
1634 return 0;
1635 }
1636
1637 static int
1638 tls_decrypt_device(struct sock *sk, struct msghdr *msg,
1639 struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
1640 {
1641 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1642 struct tls_prot_info *prot = &tls_ctx->prot_info;
1643 struct strp_msg *rxm;
1644 int pad, err;
1645
1646 if (tls_ctx->rx_conf != TLS_HW)
1647 return 0;
1648
1649 err = tls_device_decrypted(sk, tls_ctx);
1650 if (err <= 0)
1651 return err;
1652
1653 pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
1654 if (pad < 0)
1655 return pad;
1656
1657 darg->async = false;
1658 darg->skb = tls_strp_msg(ctx);
1659
1660 darg->zc &= !(prot->version == TLS_1_3_VERSION &&
1661 tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
1662
1663 rxm = strp_msg(darg->skb);
1664 rxm->full_len -= pad;
1665
1666 if (!darg->zc) {
1667
1668 darg->skb = tls_strp_msg_detach(ctx);
1669 if (!darg->skb)
1670 return -ENOMEM;
1671 } else {
1672 unsigned int off, len;
1673
1674
1675
1676
1677 off = rxm->offset + prot->prepend_size;
1678 len = rxm->full_len - prot->overhead_size;
1679
1680 err = skb_copy_datagram_msg(darg->skb, off, msg, len);
1681 if (err)
1682 return err;
1683 }
1684 return 1;
1685 }
1686
1687 static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
1688 struct tls_decrypt_arg *darg)
1689 {
1690 struct tls_context *tls_ctx = tls_get_ctx(sk);
1691 struct tls_prot_info *prot = &tls_ctx->prot_info;
1692 struct strp_msg *rxm;
1693 int err;
1694
1695 err = tls_decrypt_device(sk, msg, tls_ctx, darg);
1696 if (!err)
1697 err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
1698 if (err < 0)
1699 return err;
1700
1701 rxm = strp_msg(darg->skb);
1702 rxm->offset += prot->prepend_size;
1703 rxm->full_len -= prot->overhead_size;
1704 tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1705
1706 return 0;
1707 }
1708
1709 int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1710 {
1711 struct tls_decrypt_arg darg = { .zc = true, };
1712
1713 return tls_decrypt_sg(sk, NULL, sgout, &darg);
1714 }
1715
1716 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1717 u8 *control)
1718 {
1719 int err;
1720
1721 if (!*control) {
1722 *control = tlm->control;
1723 if (!*control)
1724 return -EBADMSG;
1725
1726 err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1727 sizeof(*control), control);
1728 if (*control != TLS_RECORD_TYPE_DATA) {
1729 if (err || msg->msg_flags & MSG_CTRUNC)
1730 return -EIO;
1731 }
1732 } else if (*control != tlm->control) {
1733 return 0;
1734 }
1735
1736 return 1;
1737 }
1738
1739 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1740 {
1741 tls_strp_msg_done(&ctx->strp);
1742 }
1743
1744
1745
1746
1747
1748
/*
 * Copy up to @len bytes of already-decrypted records queued on
 * ctx->rx_list into @msg, after skipping the first @skip bytes of the queue.
 *
 * All copied records must share one content type. *@control latches the
 * type of the first record seen (see tls_record_content_type()), and the
 * walk stops at the first record of a different type. With @is_peek the
 * queue is left untouched. Otherwise copied bytes are trimmed from each
 * record, and fully consumed records are unlinked and freed.
 *
 * Return: number of bytes copied; if nothing was copied, the error (or 0)
 * that ended the walk.
 */
static int process_rx_list(struct tls_sw_context_rx *ctx,
			   struct msghdr *msg,
			   u8 *control,
			   size_t skip,
			   size_t len,
			   bool is_peek)
{
	struct sk_buff *skb = skb_peek(&ctx->rx_list);
	struct tls_msg *tlm;
	ssize_t copied = 0;
	int err;

	/* Step over whole records covered by @skip */
	while (skip && skb) {
		struct strp_msg *rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		/* The remaining skip bytes lie within this record */
		if (skip < rxm->full_len)
			break;

		skip = skip - rxm->full_len;
		skb = skb_peek_next(skb, &ctx->rx_list);
	}

	while (len && skb) {
		struct sk_buff *next_skb;
		struct strp_msg *rxm = strp_msg(skb);
		int chunk = min_t(unsigned int, rxm->full_len - skip, len);

		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		err = skb_copy_datagram_msg(skb, rxm->offset + skip,
					    msg, chunk);
		if (err < 0)
			goto out;

		len = len - chunk;
		copied = copied + chunk;

		/* Consume the data from record if it is non-peek case */
		if (!is_peek) {
			rxm->offset = rxm->offset + chunk;
			rxm->full_len = rxm->full_len - chunk;

			/* Return if there is unconsumed data in the record */
			if (rxm->full_len - skip)
				break;
		}

		/* The remaining skip-bytes must lie in 1st record in rx_list.
		 * So from the 2nd record, 'skip' should be 0.
		 */
		skip = 0;

		/* This record has been delivered in full */
		if (msg)
			msg->msg_flags |= MSG_EOR;

		next_skb = skb_peek_next(skb, &ctx->rx_list);

		if (!is_peek) {
			__skb_unlink(skb, &ctx->rx_list);
			consume_skb(skb);
		}

		skb = next_skb;
	}
	err = 0;

out:
	return copied ? : err;
}
1827
1828 static bool
1829 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1830 size_t len_left, size_t decrypted, ssize_t done,
1831 size_t *flushed_at)
1832 {
1833 size_t max_rec;
1834
1835 if (len_left <= decrypted)
1836 return false;
1837
1838 max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1839 if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1840 return false;
1841
1842 *flushed_at = done;
1843 return sk_flush_backlog(sk);
1844 }
1845
1846 static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1847 bool nonblock)
1848 {
1849 long timeo;
1850 int err;
1851
1852 lock_sock(sk);
1853
1854 timeo = sock_rcvtimeo(sk, nonblock);
1855
1856 while (unlikely(ctx->reader_present)) {
1857 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1858
1859 ctx->reader_contended = 1;
1860
1861 add_wait_queue(&ctx->wq, &wait);
1862 sk_wait_event(sk, &timeo,
1863 !READ_ONCE(ctx->reader_present), &wait);
1864 remove_wait_queue(&ctx->wq, &wait);
1865
1866 if (timeo <= 0) {
1867 err = -EAGAIN;
1868 goto err_unlock;
1869 }
1870 if (signal_pending(current)) {
1871 err = sock_intr_errno(timeo);
1872 goto err_unlock;
1873 }
1874 }
1875
1876 WRITE_ONCE(ctx->reader_present, 1);
1877
1878 return 0;
1879
1880 err_unlock:
1881 release_sock(sk);
1882 return err;
1883 }
1884
1885 static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
1886 {
1887 if (unlikely(ctx->reader_contended)) {
1888 if (wq_has_sleeper(&ctx->wq))
1889 wake_up(&ctx->wq);
1890 else
1891 ctx->reader_contended = 0;
1892
1893 WARN_ON_ONCE(!ctx->reader_present);
1894 }
1895
1896 WRITE_ONCE(ctx->reader_present, 0);
1897 release_sock(sk);
1898 }
1899
1900 int tls_sw_recvmsg(struct sock *sk,
1901 struct msghdr *msg,
1902 size_t len,
1903 int flags,
1904 int *addr_len)
1905 {
1906 struct tls_context *tls_ctx = tls_get_ctx(sk);
1907 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1908 struct tls_prot_info *prot = &tls_ctx->prot_info;
1909 ssize_t decrypted = 0, async_copy_bytes = 0;
1910 struct sk_psock *psock;
1911 unsigned char control = 0;
1912 size_t flushed_at = 0;
1913 struct strp_msg *rxm;
1914 struct tls_msg *tlm;
1915 ssize_t copied = 0;
1916 bool async = false;
1917 int target, err;
1918 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1919 bool is_peek = flags & MSG_PEEK;
1920 bool released = true;
1921 bool bpf_strp_enabled;
1922 bool zc_capable;
1923
1924 if (unlikely(flags & MSG_ERRQUEUE))
1925 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1926
1927 psock = sk_psock_get(sk);
1928 err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1929 if (err < 0)
1930 return err;
1931 bpf_strp_enabled = sk_psock_strp_enabled(psock);
1932
1933
1934 err = ctx->async_wait.err;
1935 if (err)
1936 goto end;
1937
1938
1939 err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
1940 if (err < 0)
1941 goto end;
1942
1943 copied = err;
1944 if (len <= copied)
1945 goto end;
1946
1947 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1948 len = len - copied;
1949
1950 zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1951 ctx->zc_capable;
1952 decrypted = 0;
1953 while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
1954 struct tls_decrypt_arg darg;
1955 int to_decrypt, chunk;
1956
1957 err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
1958 released);
1959 if (err <= 0) {
1960 if (psock) {
1961 chunk = sk_msg_recvmsg(sk, psock, msg, len,
1962 flags);
1963 if (chunk > 0) {
1964 decrypted += chunk;
1965 len -= chunk;
1966 continue;
1967 }
1968 }
1969 goto recv_end;
1970 }
1971
1972 memset(&darg.inargs, 0, sizeof(darg.inargs));
1973
1974 rxm = strp_msg(tls_strp_msg(ctx));
1975 tlm = tls_msg(tls_strp_msg(ctx));
1976
1977 to_decrypt = rxm->full_len - prot->overhead_size;
1978
1979 if (zc_capable && to_decrypt <= len &&
1980 tlm->control == TLS_RECORD_TYPE_DATA)
1981 darg.zc = true;
1982
1983
1984 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1985 darg.async = ctx->async_capable;
1986 else
1987 darg.async = false;
1988
1989 err = tls_rx_one_record(sk, msg, &darg);
1990 if (err < 0) {
1991 tls_err_abort(sk, -EBADMSG);
1992 goto recv_end;
1993 }
1994
1995 async |= darg.async;
1996
1997
1998
1999
2000
2001
2002
2003
2004 err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
2005 if (err <= 0) {
2006 DEBUG_NET_WARN_ON_ONCE(darg.zc);
2007 tls_rx_rec_done(ctx);
2008 put_on_rx_list_err:
2009 __skb_queue_tail(&ctx->rx_list, darg.skb);
2010 goto recv_end;
2011 }
2012
2013
2014 released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
2015 decrypted + copied,
2016 &flushed_at);
2017
2018
2019 rxm = strp_msg(darg.skb);
2020 chunk = rxm->full_len;
2021 tls_rx_rec_done(ctx);
2022
2023 if (!darg.zc) {
2024 bool partially_consumed = chunk > len;
2025 struct sk_buff *skb = darg.skb;
2026
2027 DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
2028
2029 if (async) {
2030
2031 chunk = min_t(int, to_decrypt, len);
2032 async_copy_bytes += chunk;
2033 put_on_rx_list:
2034 decrypted += chunk;
2035 len -= chunk;
2036 __skb_queue_tail(&ctx->rx_list, skb);
2037 continue;
2038 }
2039
2040 if (bpf_strp_enabled) {
2041 released = true;
2042 err = sk_psock_tls_strp_read(psock, skb);
2043 if (err != __SK_PASS) {
2044 rxm->offset = rxm->offset + rxm->full_len;
2045 rxm->full_len = 0;
2046 if (err == __SK_DROP)
2047 consume_skb(skb);
2048 continue;
2049 }
2050 }
2051
2052 if (partially_consumed)
2053 chunk = len;
2054
2055 err = skb_copy_datagram_msg(skb, rxm->offset,
2056 msg, chunk);
2057 if (err < 0)
2058 goto put_on_rx_list_err;
2059
2060 if (is_peek)
2061 goto put_on_rx_list;
2062
2063 if (partially_consumed) {
2064 rxm->offset += chunk;
2065 rxm->full_len -= chunk;
2066 goto put_on_rx_list;
2067 }
2068
2069 consume_skb(skb);
2070 }
2071
2072 decrypted += chunk;
2073 len -= chunk;
2074
2075
2076
2077
2078 msg->msg_flags |= MSG_EOR;
2079 if (control != TLS_RECORD_TYPE_DATA)
2080 break;
2081 }
2082
2083 recv_end:
2084 if (async) {
2085 int ret, pending;
2086
2087
2088 spin_lock_bh(&ctx->decrypt_compl_lock);
2089 reinit_completion(&ctx->async_wait.completion);
2090 pending = atomic_read(&ctx->decrypt_pending);
2091 spin_unlock_bh(&ctx->decrypt_compl_lock);
2092 ret = 0;
2093 if (pending)
2094 ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2095 __skb_queue_purge(&ctx->async_hold);
2096
2097 if (ret) {
2098 if (err >= 0 || err == -EINPROGRESS)
2099 err = ret;
2100 decrypted = 0;
2101 goto end;
2102 }
2103
2104
2105 if (is_peek || is_kvec)
2106 err = process_rx_list(ctx, msg, &control, copied,
2107 decrypted, is_peek);
2108 else
2109 err = process_rx_list(ctx, msg, &control, 0,
2110 async_copy_bytes, is_peek);
2111 decrypted = max(err, 0);
2112 }
2113
2114 copied += decrypted;
2115
2116 end:
2117 tls_rx_reader_unlock(sk, ctx);
2118 if (psock)
2119 sk_psock_put(sk, psock);
2120 return copied ? : err;
2121 }
2122
2123 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
2124 struct pipe_inode_info *pipe,
2125 size_t len, unsigned int flags)
2126 {
2127 struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2128 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2129 struct strp_msg *rxm = NULL;
2130 struct sock *sk = sock->sk;
2131 struct tls_msg *tlm;
2132 struct sk_buff *skb;
2133 ssize_t copied = 0;
2134 int chunk;
2135 int err;
2136
2137 err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2138 if (err < 0)
2139 return err;
2140
2141 if (!skb_queue_empty(&ctx->rx_list)) {
2142 skb = __skb_dequeue(&ctx->rx_list);
2143 } else {
2144 struct tls_decrypt_arg darg;
2145
2146 err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2147 true);
2148 if (err <= 0)
2149 goto splice_read_end;
2150
2151 memset(&darg.inargs, 0, sizeof(darg.inargs));
2152
2153 err = tls_rx_one_record(sk, NULL, &darg);
2154 if (err < 0) {
2155 tls_err_abort(sk, -EBADMSG);
2156 goto splice_read_end;
2157 }
2158
2159 tls_rx_rec_done(ctx);
2160 skb = darg.skb;
2161 }
2162
2163 rxm = strp_msg(skb);
2164 tlm = tls_msg(skb);
2165
2166
2167 if (tlm->control != TLS_RECORD_TYPE_DATA) {
2168 err = -EINVAL;
2169 goto splice_requeue;
2170 }
2171
2172 chunk = min_t(unsigned int, rxm->full_len, len);
2173 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2174 if (copied < 0)
2175 goto splice_requeue;
2176
2177 if (chunk < rxm->full_len) {
2178 rxm->offset += len;
2179 rxm->full_len -= len;
2180 goto splice_requeue;
2181 }
2182
2183 consume_skb(skb);
2184
2185 splice_read_end:
2186 tls_rx_reader_unlock(sk, ctx);
2187 return copied ? : err;
2188
2189 splice_requeue:
2190 __skb_queue_head(&ctx->rx_list, skb);
2191 goto splice_read_end;
2192 }
2193
2194 bool tls_sw_sock_is_readable(struct sock *sk)
2195 {
2196 struct tls_context *tls_ctx = tls_get_ctx(sk);
2197 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2198 bool ingress_empty = true;
2199 struct sk_psock *psock;
2200
2201 rcu_read_lock();
2202 psock = sk_psock(sk);
2203 if (psock)
2204 ingress_empty = list_empty(&psock->ingress_msg);
2205 rcu_read_unlock();
2206
2207 return !ingress_empty || tls_strp_msg_ready(ctx) ||
2208 !skb_queue_empty(&ctx->rx_list);
2209 }
2210
2211 int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
2212 {
2213 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2214 struct tls_prot_info *prot = &tls_ctx->prot_info;
2215 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2216 size_t cipher_overhead;
2217 size_t data_len = 0;
2218 int ret;
2219
2220
2221 if (strp->stm.offset + prot->prepend_size > skb->len)
2222 return 0;
2223
2224
2225 if (WARN_ON(prot->prepend_size > sizeof(header))) {
2226 ret = -EINVAL;
2227 goto read_failure;
2228 }
2229
2230
2231 ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
2232 if (ret < 0)
2233 goto read_failure;
2234
2235 strp->mark = header[0];
2236
2237 data_len = ((header[4] & 0xFF) | (header[3] << 8));
2238
2239 cipher_overhead = prot->tag_size;
2240 if (prot->version != TLS_1_3_VERSION &&
2241 prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2242 cipher_overhead += prot->iv_size;
2243
2244 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2245 prot->tail_size) {
2246 ret = -EMSGSIZE;
2247 goto read_failure;
2248 }
2249 if (data_len < cipher_overhead) {
2250 ret = -EBADMSG;
2251 goto read_failure;
2252 }
2253
2254
2255 if (header[1] != TLS_1_2_VERSION_MINOR ||
2256 header[2] != TLS_1_2_VERSION_MAJOR) {
2257 ret = -EINVAL;
2258 goto read_failure;
2259 }
2260
2261 tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2262 TCP_SKB_CB(skb)->seq + strp->stm.offset);
2263 return data_len + TLS_HEADER_SIZE;
2264
2265 read_failure:
2266 tls_err_abort(strp->sk, ret);
2267
2268 return ret;
2269 }
2270
2271 void tls_rx_msg_ready(struct tls_strparser *strp)
2272 {
2273 struct tls_sw_context_rx *ctx;
2274
2275 ctx = container_of(strp, struct tls_sw_context_rx, strp);
2276 ctx->saved_data_ready(strp->sk);
2277 }
2278
2279 static void tls_data_ready(struct sock *sk)
2280 {
2281 struct tls_context *tls_ctx = tls_get_ctx(sk);
2282 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2283 struct sk_psock *psock;
2284
2285 tls_strp_data_ready(&ctx->strp);
2286
2287 psock = sk_psock_get(sk);
2288 if (psock) {
2289 if (!list_empty(&psock->ingress_msg))
2290 ctx->saved_data_ready(sk);
2291 sk_psock_put(sk, psock);
2292 }
2293 }
2294
2295 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2296 {
2297 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2298
2299 set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2300 set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2301 cancel_delayed_work_sync(&ctx->tx_work.work);
2302 }
2303
2304 void tls_sw_release_resources_tx(struct sock *sk)
2305 {
2306 struct tls_context *tls_ctx = tls_get_ctx(sk);
2307 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2308 struct tls_rec *rec, *tmp;
2309 int pending;
2310
2311
2312 spin_lock_bh(&ctx->encrypt_compl_lock);
2313 ctx->async_notify = true;
2314 pending = atomic_read(&ctx->encrypt_pending);
2315 spin_unlock_bh(&ctx->encrypt_compl_lock);
2316
2317 if (pending)
2318 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2319
2320 tls_tx_records(sk, -1);
2321
2322
2323
2324
2325 if (tls_ctx->partially_sent_record) {
2326 tls_free_partial_record(sk, tls_ctx);
2327 rec = list_first_entry(&ctx->tx_list,
2328 struct tls_rec, list);
2329 list_del(&rec->list);
2330 sk_msg_free(sk, &rec->msg_plaintext);
2331 kfree(rec);
2332 }
2333
2334 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2335 list_del(&rec->list);
2336 sk_msg_free(sk, &rec->msg_encrypted);
2337 sk_msg_free(sk, &rec->msg_plaintext);
2338 kfree(rec);
2339 }
2340
2341 crypto_free_aead(ctx->aead_send);
2342 tls_free_open_rec(sk);
2343 }
2344
/* Free the software TX context attached to @tls_ctx. */
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
{
	kfree(tls_sw_ctx_tx(tls_ctx));
}
2351
2352 void tls_sw_release_resources_rx(struct sock *sk)
2353 {
2354 struct tls_context *tls_ctx = tls_get_ctx(sk);
2355 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2356
2357 kfree(tls_ctx->rx.rec_seq);
2358 kfree(tls_ctx->rx.iv);
2359
2360 if (ctx->aead_recv) {
2361 __skb_queue_purge(&ctx->rx_list);
2362 crypto_free_aead(ctx->aead_recv);
2363 tls_strp_stop(&ctx->strp);
2364
2365
2366
2367
2368 if (ctx->saved_data_ready) {
2369 write_lock_bh(&sk->sk_callback_lock);
2370 sk->sk_data_ready = ctx->saved_data_ready;
2371 write_unlock_bh(&sk->sk_callback_lock);
2372 }
2373 }
2374 }
2375
2376 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2377 {
2378 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2379
2380 tls_strp_done(&ctx->strp);
2381 }
2382
/* Free the software RX context attached to @tls_ctx. */
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
{
	kfree(tls_sw_ctx_rx(tls_ctx));
}
2389
/* Release RX resources, then free the RX context itself. */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	tls_sw_release_resources_rx(sk);
	tls_sw_free_ctx_rx(ctx);
}
2397
2398
2399 static void tx_work_handler(struct work_struct *work)
2400 {
2401 struct delayed_work *delayed_work = to_delayed_work(work);
2402 struct tx_work *tx_work = container_of(delayed_work,
2403 struct tx_work, work);
2404 struct sock *sk = tx_work->sk;
2405 struct tls_context *tls_ctx = tls_get_ctx(sk);
2406 struct tls_sw_context_tx *ctx;
2407
2408 if (unlikely(!tls_ctx))
2409 return;
2410
2411 ctx = tls_sw_ctx_tx(tls_ctx);
2412 if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2413 return;
2414
2415 if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2416 return;
2417 mutex_lock(&tls_ctx->tx_lock);
2418 lock_sock(sk);
2419 tls_tx_records(sk, -1);
2420 release_sock(sk);
2421 mutex_unlock(&tls_ctx->tx_lock);
2422 }
2423
2424 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2425 {
2426 struct tls_rec *rec;
2427
2428 rec = list_first_entry(&ctx->tx_list, struct tls_rec, list);
2429 if (!rec)
2430 return false;
2431
2432 return READ_ONCE(rec->tx_ready);
2433 }
2434
2435 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2436 {
2437 struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2438
2439
2440 if (tls_is_tx_ready(tx_ctx) &&
2441 !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2442 schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2443 }
2444
2445 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2446 {
2447 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2448
2449 write_lock_bh(&sk->sk_callback_lock);
2450 rx_ctx->saved_data_ready = sk->sk_data_ready;
2451 sk->sk_data_ready = tls_data_ready;
2452 write_unlock_bh(&sk->sk_callback_lock);
2453 }
2454
2455 void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2456 {
2457 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2458
2459 rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2460 tls_ctx->prot_info.version != TLS_1_3_VERSION;
2461 }
2462
2463 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2464 {
2465 struct tls_context *tls_ctx = tls_get_ctx(sk);
2466 struct tls_prot_info *prot = &tls_ctx->prot_info;
2467 struct tls_crypto_info *crypto_info;
2468 struct tls_sw_context_tx *sw_ctx_tx = NULL;
2469 struct tls_sw_context_rx *sw_ctx_rx = NULL;
2470 struct cipher_context *cctx;
2471 struct crypto_aead **aead;
2472 u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2473 struct crypto_tfm *tfm;
2474 char *iv, *rec_seq, *key, *salt, *cipher_name;
2475 size_t keysize;
2476 int rc = 0;
2477
2478 if (!ctx) {
2479 rc = -EINVAL;
2480 goto out;
2481 }
2482
2483 if (tx) {
2484 if (!ctx->priv_ctx_tx) {
2485 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2486 if (!sw_ctx_tx) {
2487 rc = -ENOMEM;
2488 goto out;
2489 }
2490 ctx->priv_ctx_tx = sw_ctx_tx;
2491 } else {
2492 sw_ctx_tx =
2493 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2494 }
2495 } else {
2496 if (!ctx->priv_ctx_rx) {
2497 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2498 if (!sw_ctx_rx) {
2499 rc = -ENOMEM;
2500 goto out;
2501 }
2502 ctx->priv_ctx_rx = sw_ctx_rx;
2503 } else {
2504 sw_ctx_rx =
2505 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2506 }
2507 }
2508
2509 if (tx) {
2510 crypto_init_wait(&sw_ctx_tx->async_wait);
2511 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2512 crypto_info = &ctx->crypto_send.info;
2513 cctx = &ctx->tx;
2514 aead = &sw_ctx_tx->aead_send;
2515 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2516 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2517 sw_ctx_tx->tx_work.sk = sk;
2518 } else {
2519 crypto_init_wait(&sw_ctx_rx->async_wait);
2520 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2521 init_waitqueue_head(&sw_ctx_rx->wq);
2522 crypto_info = &ctx->crypto_recv.info;
2523 cctx = &ctx->rx;
2524 skb_queue_head_init(&sw_ctx_rx->rx_list);
2525 skb_queue_head_init(&sw_ctx_rx->async_hold);
2526 aead = &sw_ctx_rx->aead_recv;
2527 }
2528
2529 switch (crypto_info->cipher_type) {
2530 case TLS_CIPHER_AES_GCM_128: {
2531 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2532
2533 gcm_128_info = (void *)crypto_info;
2534 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2535 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2536 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2537 iv = gcm_128_info->iv;
2538 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2539 rec_seq = gcm_128_info->rec_seq;
2540 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2541 key = gcm_128_info->key;
2542 salt = gcm_128_info->salt;
2543 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2544 cipher_name = "gcm(aes)";
2545 break;
2546 }
2547 case TLS_CIPHER_AES_GCM_256: {
2548 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2549
2550 gcm_256_info = (void *)crypto_info;
2551 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2552 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2553 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2554 iv = gcm_256_info->iv;
2555 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2556 rec_seq = gcm_256_info->rec_seq;
2557 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2558 key = gcm_256_info->key;
2559 salt = gcm_256_info->salt;
2560 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2561 cipher_name = "gcm(aes)";
2562 break;
2563 }
2564 case TLS_CIPHER_AES_CCM_128: {
2565 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2566
2567 ccm_128_info = (void *)crypto_info;
2568 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2569 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2570 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2571 iv = ccm_128_info->iv;
2572 rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2573 rec_seq = ccm_128_info->rec_seq;
2574 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2575 key = ccm_128_info->key;
2576 salt = ccm_128_info->salt;
2577 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2578 cipher_name = "ccm(aes)";
2579 break;
2580 }
2581 case TLS_CIPHER_CHACHA20_POLY1305: {
2582 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
2583
2584 chacha20_poly1305_info = (void *)crypto_info;
2585 nonce_size = 0;
2586 tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
2587 iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
2588 iv = chacha20_poly1305_info->iv;
2589 rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
2590 rec_seq = chacha20_poly1305_info->rec_seq;
2591 keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
2592 key = chacha20_poly1305_info->key;
2593 salt = chacha20_poly1305_info->salt;
2594 salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
2595 cipher_name = "rfc7539(chacha20,poly1305)";
2596 break;
2597 }
2598 case TLS_CIPHER_SM4_GCM: {
2599 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
2600
2601 sm4_gcm_info = (void *)crypto_info;
2602 nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2603 tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
2604 iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2605 iv = sm4_gcm_info->iv;
2606 rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
2607 rec_seq = sm4_gcm_info->rec_seq;
2608 keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
2609 key = sm4_gcm_info->key;
2610 salt = sm4_gcm_info->salt;
2611 salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
2612 cipher_name = "gcm(sm4)";
2613 break;
2614 }
2615 case TLS_CIPHER_SM4_CCM: {
2616 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
2617
2618 sm4_ccm_info = (void *)crypto_info;
2619 nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2620 tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
2621 iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2622 iv = sm4_ccm_info->iv;
2623 rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
2624 rec_seq = sm4_ccm_info->rec_seq;
2625 keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
2626 key = sm4_ccm_info->key;
2627 salt = sm4_ccm_info->salt;
2628 salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
2629 cipher_name = "ccm(sm4)";
2630 break;
2631 }
2632 default:
2633 rc = -EINVAL;
2634 goto free_priv;
2635 }
2636
2637 if (crypto_info->version == TLS_1_3_VERSION) {
2638 nonce_size = 0;
2639 prot->aad_size = TLS_HEADER_SIZE;
2640 prot->tail_size = 1;
2641 } else {
2642 prot->aad_size = TLS_AAD_SPACE_SIZE;
2643 prot->tail_size = 0;
2644 }
2645
2646
2647 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2648 rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
2649 prot->aad_size > TLS_MAX_AAD_SIZE) {
2650 rc = -EINVAL;
2651 goto free_priv;
2652 }
2653
2654 prot->version = crypto_info->version;
2655 prot->cipher_type = crypto_info->cipher_type;
2656 prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2657 prot->tag_size = tag_size;
2658 prot->overhead_size = prot->prepend_size +
2659 prot->tag_size + prot->tail_size;
2660 prot->iv_size = iv_size;
2661 prot->salt_size = salt_size;
2662 cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2663 if (!cctx->iv) {
2664 rc = -ENOMEM;
2665 goto free_priv;
2666 }
2667
2668 prot->rec_seq_size = rec_seq_size;
2669 memcpy(cctx->iv, salt, salt_size);
2670 memcpy(cctx->iv + salt_size, iv, iv_size);
2671 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2672 if (!cctx->rec_seq) {
2673 rc = -ENOMEM;
2674 goto free_iv;
2675 }
2676
2677 if (!*aead) {
2678 *aead = crypto_alloc_aead(cipher_name, 0, 0);
2679 if (IS_ERR(*aead)) {
2680 rc = PTR_ERR(*aead);
2681 *aead = NULL;
2682 goto free_rec_seq;
2683 }
2684 }
2685
2686 ctx->push_pending_record = tls_sw_push_pending_record;
2687
2688 rc = crypto_aead_setkey(*aead, key, keysize);
2689
2690 if (rc)
2691 goto free_aead;
2692
2693 rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2694 if (rc)
2695 goto free_aead;
2696
2697 if (sw_ctx_rx) {
2698 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2699
2700 tls_update_rx_zc_capable(ctx);
2701 sw_ctx_rx->async_capable =
2702 crypto_info->version != TLS_1_3_VERSION &&
2703 !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2704
2705 rc = tls_strp_init(&sw_ctx_rx->strp, sk);
2706 if (rc)
2707 goto free_aead;
2708 }
2709
2710 goto out;
2711
2712 free_aead:
2713 crypto_free_aead(*aead);
2714 *aead = NULL;
2715 free_rec_seq:
2716 kfree(cctx->rec_seq);
2717 cctx->rec_seq = NULL;
2718 free_iv:
2719 kfree(cctx->iv);
2720 cctx->iv = NULL;
2721 free_priv:
2722 if (tx) {
2723 kfree(ctx->priv_ctx_tx);
2724 ctx->priv_ctx_tx = NULL;
2725 } else {
2726 kfree(ctx->priv_ctx_rx);
2727 ctx->priv_ctx_rx = NULL;
2728 }
2729 out:
2730 return rc;
2731 }