0001
0002
0003
0004 #include <linux/skbuff.h>
0005 #include <linux/workqueue.h>
0006 #include <net/strparser.h>
0007 #include <net/tcp.h>
0008 #include <net/sock.h>
0009 #include <net/tls.h>
0010
0011 #include "tls.h"
0012
/* Workqueue used to run the parser outside the data_ready path: work is
 * queued when the socket is owned by a user or when a read hit -ENOMEM.
 * Created in tls_strp_dev_init() and destroyed in tls_strp_dev_exit().
 */
static struct workqueue_struct *tls_strp_wq;
0014
0015 static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
0016 {
0017 if (strp->stopped)
0018 return;
0019
0020 strp->stopped = 1;
0021
0022
0023 strp->sk->sk_err = -err;
0024 sk_error_report(strp->sk);
0025 }
0026
0027 static void tls_strp_anchor_free(struct tls_strparser *strp)
0028 {
0029 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
0030
0031 DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
0032 shinfo->frag_list = NULL;
0033 consume_skb(strp->anchor);
0034 strp->anchor = NULL;
0035 }
0036
0037
0038 static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
0039 {
0040 struct strp_msg *rxm;
0041 struct sk_buff *skb;
0042 int i, err, offset;
0043
0044 skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
0045 &err, strp->sk->sk_allocation);
0046 if (!skb)
0047 return NULL;
0048
0049 offset = strp->stm.offset;
0050 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
0051 skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
0052
0053 WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
0054 skb_frag_address(frag),
0055 skb_frag_size(frag)));
0056 offset += skb_frag_size(frag);
0057 }
0058
0059 skb_copy_header(skb, strp->anchor);
0060 rxm = strp_msg(skb);
0061 rxm->offset = 0;
0062 return skb;
0063 }
0064
0065
0066 struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
0067 {
0068 struct tls_strparser *strp = &ctx->strp;
0069
0070 #ifdef CONFIG_TLS_DEVICE
0071 DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
0072 #else
0073
0074
0075
0076 WARN_ON(1);
0077 #endif
0078
0079 if (strp->copy_mode) {
0080 struct sk_buff *skb;
0081
0082
0083
0084
0085
0086 skb = alloc_skb(0, strp->sk->sk_allocation);
0087 if (!skb)
0088 return NULL;
0089
0090 swap(strp->anchor, skb);
0091 return skb;
0092 }
0093
0094 return tls_strp_msg_make_copy(strp);
0095 }
0096
0097
0098
0099
0100
0101 int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
0102 {
0103 struct tls_strparser *strp = &ctx->strp;
0104 struct sk_buff *skb;
0105
0106 if (strp->copy_mode)
0107 return 0;
0108
0109 skb = tls_strp_msg_make_copy(strp);
0110 if (!skb)
0111 return -ENOMEM;
0112
0113 tls_strp_anchor_free(strp);
0114 strp->anchor = skb;
0115
0116 tcp_read_done(strp->sk, strp->stm.full_len);
0117 strp->copy_mode = 1;
0118
0119 return 0;
0120 }
0121
0122
0123
0124
0125
0126 int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
0127 {
0128 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
0129
0130 if (strp->copy_mode) {
0131 struct sk_buff *skb;
0132
0133 WARN_ON_ONCE(!shinfo->nr_frags);
0134
0135
0136 skb = alloc_skb(0, strp->sk->sk_allocation);
0137 if (!skb)
0138 return -ENOMEM;
0139
0140 __skb_queue_tail(dst, strp->anchor);
0141 strp->anchor = skb;
0142 } else {
0143 struct sk_buff *iter, *clone;
0144 int chunk, len, offset;
0145
0146 offset = strp->stm.offset;
0147 len = strp->stm.full_len;
0148 iter = shinfo->frag_list;
0149
0150 while (len > 0) {
0151 if (iter->len <= offset) {
0152 offset -= iter->len;
0153 goto next;
0154 }
0155
0156 chunk = iter->len - offset;
0157 offset = 0;
0158
0159 clone = skb_clone(iter, strp->sk->sk_allocation);
0160 if (!clone)
0161 return -ENOMEM;
0162 __skb_queue_tail(dst, clone);
0163
0164 len -= chunk;
0165 next:
0166 iter = iter->next;
0167 }
0168 }
0169
0170 return 0;
0171 }
0172
0173 static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
0174 {
0175 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
0176 int i;
0177
0178 DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
0179
0180 for (i = 0; i < shinfo->nr_frags; i++)
0181 __skb_frag_unref(&shinfo->frags[i], false);
0182 shinfo->nr_frags = 0;
0183 strp->copy_mode = 0;
0184 }
0185
0186 static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
0187 unsigned int offset, size_t in_len)
0188 {
0189 struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
0190 struct sk_buff *skb;
0191 skb_frag_t *frag;
0192 size_t len, chunk;
0193 int sz;
0194
0195 if (strp->msg_ready)
0196 return 0;
0197
0198 skb = strp->anchor;
0199 frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
0200
0201 len = in_len;
0202
0203 if (!strp->stm.full_len) {
0204
0205 chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
0206 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
0207 skb_frag_address(frag) +
0208 skb_frag_size(frag),
0209 chunk));
0210
0211 sz = tls_rx_msg_size(strp, strp->anchor);
0212 if (sz < 0) {
0213 desc->error = sz;
0214 return 0;
0215 }
0216
0217
0218 if (sz > 0)
0219 chunk = min_t(size_t, chunk, sz - skb->len);
0220
0221 skb->len += chunk;
0222 skb->data_len += chunk;
0223 skb_frag_size_add(frag, chunk);
0224 frag++;
0225 len -= chunk;
0226 offset += chunk;
0227
0228 strp->stm.full_len = sz;
0229 if (!strp->stm.full_len)
0230 goto read_done;
0231 }
0232
0233
0234 while (len && strp->stm.full_len > skb->len) {
0235 chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
0236 chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
0237 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
0238 skb_frag_address(frag) +
0239 skb_frag_size(frag),
0240 chunk));
0241
0242 skb->len += chunk;
0243 skb->data_len += chunk;
0244 skb_frag_size_add(frag, chunk);
0245 frag++;
0246 len -= chunk;
0247 offset += chunk;
0248 }
0249
0250 if (strp->stm.full_len == skb->len) {
0251 desc->count = 0;
0252
0253 strp->msg_ready = 1;
0254 tls_rx_msg_ready(strp);
0255 }
0256
0257 read_done:
0258 return in_len - len;
0259 }
0260
0261 static int tls_strp_read_copyin(struct tls_strparser *strp)
0262 {
0263 struct socket *sock = strp->sk->sk_socket;
0264 read_descriptor_t desc;
0265
0266 desc.arg.data = strp;
0267 desc.error = 0;
0268 desc.count = 1;
0269
0270
0271 sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
0272
0273 return desc.error;
0274 }
0275
0276 static int tls_strp_read_short(struct tls_strparser *strp)
0277 {
0278 struct skb_shared_info *shinfo;
0279 struct page *page;
0280 int need_spc, len;
0281
0282
0283
0284
0285
0286 if (likely(!tcp_epollin_ready(strp->sk, INT_MAX)))
0287 return 0;
0288
0289 shinfo = skb_shinfo(strp->anchor);
0290 shinfo->frag_list = NULL;
0291
0292
0293 need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
0294
0295 for (len = need_spc; len > 0; len -= PAGE_SIZE) {
0296 page = alloc_page(strp->sk->sk_allocation);
0297 if (!page) {
0298 tls_strp_flush_anchor_copy(strp);
0299 return -ENOMEM;
0300 }
0301
0302 skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
0303 page, 0, 0);
0304 }
0305
0306 strp->copy_mode = 1;
0307 strp->stm.offset = 0;
0308
0309 strp->anchor->len = 0;
0310 strp->anchor->data_len = 0;
0311 strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
0312
0313 tls_strp_read_copyin(strp);
0314
0315 return 0;
0316 }
0317
0318 static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
0319 {
0320 struct tcp_sock *tp = tcp_sk(strp->sk);
0321 struct sk_buff *first;
0322 u32 offset;
0323
0324 first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
0325 if (WARN_ON_ONCE(!first))
0326 return;
0327
0328
0329 strp->anchor->len = offset + len;
0330 strp->anchor->data_len = offset + len;
0331 strp->anchor->truesize = offset + len;
0332
0333 skb_shinfo(strp->anchor)->frag_list = first;
0334
0335 skb_copy_header(strp->anchor, first);
0336 strp->anchor->destructor = NULL;
0337
0338 strp->stm.offset = offset;
0339 }
0340
0341 void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
0342 {
0343 struct strp_msg *rxm;
0344 struct tls_msg *tlm;
0345
0346 DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
0347 DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
0348
0349 if (!strp->copy_mode && force_refresh) {
0350 if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
0351 return;
0352
0353 tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
0354 }
0355
0356 rxm = strp_msg(strp->anchor);
0357 rxm->full_len = strp->stm.full_len;
0358 rxm->offset = strp->stm.offset;
0359 tlm = tls_msg(strp->anchor);
0360 tlm->control = strp->mark;
0361 }
0362
0363
0364 static int tls_strp_read_sock(struct tls_strparser *strp)
0365 {
0366 int sz, inq;
0367
0368 inq = tcp_inq(strp->sk);
0369 if (inq < 1)
0370 return 0;
0371
0372 if (unlikely(strp->copy_mode))
0373 return tls_strp_read_copyin(strp);
0374
0375 if (inq < strp->stm.full_len)
0376 return tls_strp_read_short(strp);
0377
0378 if (!strp->stm.full_len) {
0379 tls_strp_load_anchor_with_queue(strp, inq);
0380
0381 sz = tls_rx_msg_size(strp, strp->anchor);
0382 if (sz < 0) {
0383 tls_strp_abort_strp(strp, sz);
0384 return sz;
0385 }
0386
0387 strp->stm.full_len = sz;
0388
0389 if (!strp->stm.full_len || inq < strp->stm.full_len)
0390 return tls_strp_read_short(strp);
0391 }
0392
0393 strp->msg_ready = 1;
0394 tls_rx_msg_ready(strp);
0395
0396 return 0;
0397 }
0398
0399 void tls_strp_check_rcv(struct tls_strparser *strp)
0400 {
0401 if (unlikely(strp->stopped) || strp->msg_ready)
0402 return;
0403
0404 if (tls_strp_read_sock(strp) == -ENOMEM)
0405 queue_work(tls_strp_wq, &strp->work);
0406 }
0407
0408
0409 void tls_strp_data_ready(struct tls_strparser *strp)
0410 {
0411
0412
0413
0414
0415
0416
0417
0418 if (sock_owned_by_user_nocheck(strp->sk)) {
0419 queue_work(tls_strp_wq, &strp->work);
0420 return;
0421 }
0422
0423 tls_strp_check_rcv(strp);
0424 }
0425
0426 static void tls_strp_work(struct work_struct *w)
0427 {
0428 struct tls_strparser *strp =
0429 container_of(w, struct tls_strparser, work);
0430
0431 lock_sock(strp->sk);
0432 tls_strp_check_rcv(strp);
0433 release_sock(strp->sk);
0434 }
0435
0436 void tls_strp_msg_done(struct tls_strparser *strp)
0437 {
0438 WARN_ON(!strp->stm.full_len);
0439
0440 if (likely(!strp->copy_mode))
0441 tcp_read_done(strp->sk, strp->stm.full_len);
0442 else
0443 tls_strp_flush_anchor_copy(strp);
0444
0445 strp->msg_ready = 0;
0446 memset(&strp->stm, 0, sizeof(strp->stm));
0447
0448 tls_strp_check_rcv(strp);
0449 }
0450
/*
 * tls_strp_stop() - mark the parser as stopped.
 *
 * Once stopped, tls_strp_check_rcv() returns without reading, and
 * tls_strp_abort_strp() no longer reports errors. tls_strp_done() warns if
 * this has not been called.
 */
void tls_strp_stop(struct tls_strparser *strp)
{
	strp->stopped = 1;
}
0455
0456 int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
0457 {
0458 memset(strp, 0, sizeof(*strp));
0459
0460 strp->sk = sk;
0461
0462 strp->anchor = alloc_skb(0, GFP_KERNEL);
0463 if (!strp->anchor)
0464 return -ENOMEM;
0465
0466 INIT_WORK(&strp->work, tls_strp_work);
0467
0468 return 0;
0469 }
0470
0471
0472
0473
/*
 * tls_strp_done() - tear down the parser.
 *
 * Warns if the parser has not been stopped with tls_strp_stop(). It then
 * cancels the work item, waiting for a running tls_strp_work() to finish,
 * and frees the anchor skb.
 *
 * NOTE(review): tls_strp_work() takes lock_sock(), so callers presumably
 * must not hold the lower socket's lock here, or cancel_work_sync() could
 * deadlock. Confirm.
 */
void tls_strp_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stopped);

	cancel_work_sync(&strp->work);
	tls_strp_anchor_free(strp);
}
0481
0482 int __init tls_strp_dev_init(void)
0483 {
0484 tls_strp_wq = create_workqueue("tls-strp");
0485 if (unlikely(!tls_strp_wq))
0486 return -ENOMEM;
0487
0488 return 0;
0489 }
0490
/* Destroy the parser workqueue created by tls_strp_dev_init(). */
void tls_strp_dev_exit(void)
{
	destroy_workqueue(tls_strp_wq);
}