// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */

#include <linux/skbuff.h>
#include <linux/workqueue.h>
#include <net/strparser.h>
#include <net/tcp.h>
#include <net/sock.h>
#include <net/tls.h>

#include "tls.h"

static struct workqueue_struct *tls_strp_wq;

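/* Stop the parser and report @err as a socket error on the lower socket. */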
static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
{
    if (strp->stopped)
        return;

    strp->stopped = 1;

    /* Report an error on the lower socket */
    strp->sk->sk_err = -err;
    sk_error_report(strp->sk);
}

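/* Release the anchor skb. The frag_list, when set, points at skbs still
 * owned by the TCP receive queue, so detach it before freeing.
 */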
static void tls_strp_anchor_free(struct tls_strparser *strp)
{
    struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

    DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
    shinfo->frag_list = NULL;
    consume_skb(strp->anchor);
    strp->anchor = NULL;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
    struct strp_msg *rxm;
    struct sk_buff *skb;
    int i, err, offset;

    skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
                   &err, strp->sk->sk_allocation);
    if (!skb)
        return NULL;

    offset = strp->stm.offset;
    for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
        skb_frag_t *frag = &skb_shinfo(skb)->frags[i];

        WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
                       skb_frag_address(frag),
                       skb_frag_size(frag)));
        offset += skb_frag_size(frag);
    }

    skb_copy_header(skb, strp->anchor);
    rxm = strp_msg(skb);
    rxm->offset = 0;
    return skb;
}

/* Steal the input skb; the input msg is invalid after calling this function */
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
{
    struct tls_strparser *strp = &ctx->strp;

#ifdef CONFIG_TLS_DEVICE
    DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
#else
    /* This function turns an input into an output; that can only
     * happen if we have offload.
     */
    WARN_ON(1);
#endif

    if (strp->copy_mode) {
        struct sk_buff *skb;

        /* Replace anchor with an empty skb; this is a little
         * dangerous, but __tls_cur_msg() warns on empty skbs
         * so hopefully we'll catch abuses.
         */
        skb = alloc_skb(0, strp->sk->sk_allocation);
        if (!skb)
            return NULL;

        swap(strp->anchor, skb);
        return skb;
    }

    return tls_strp_msg_make_copy(strp);
}

/* Force the input skb to be in copy mode. The data ownership remains
 * with the input skb itself (meaning unpause will wipe it) but it can
 * be modified.
 */
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
{
    struct tls_strparser *strp = &ctx->strp;
    struct sk_buff *skb;

    if (strp->copy_mode)
        return 0;

    skb = tls_strp_msg_make_copy(strp);
    if (!skb)
        return -ENOMEM;

    tls_strp_anchor_free(strp);
    strp->anchor = skb;

    tcp_read_done(strp->sk, strp->stm.full_len);
    strp->copy_mode = 1;

    return 0;
}

/* Make a clone (in the skb sense) of the input msg to keep a reference
 * to the underlying data. The reference-holding skbs get placed on
 * @dst.
 */
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
{
    struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

    if (strp->copy_mode) {
        struct sk_buff *skb;

        WARN_ON_ONCE(!shinfo->nr_frags);

        /* We can't skb_clone() the anchor, it gets wiped by unpause */
        skb = alloc_skb(0, strp->sk->sk_allocation);
        if (!skb)
            return -ENOMEM;

        __skb_queue_tail(dst, strp->anchor);
        strp->anchor = skb;
    } else {
        struct sk_buff *iter, *clone;
        int chunk, len, offset;

        offset = strp->stm.offset;
        len = strp->stm.full_len;
        iter = shinfo->frag_list;

        while (len > 0) {
            if (iter->len <= offset) {
                offset -= iter->len;
                goto next;
            }

            chunk = iter->len - offset;
            offset = 0;

            clone = skb_clone(iter, strp->sk->sk_allocation);
            if (!clone)
                return -ENOMEM;
            __skb_queue_tail(dst, clone);

            len -= chunk;
next:
            iter = iter->next;
        }
    }

    return 0;
}

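/* Drop the page frags accumulated in the anchor while in copy mode and
 * leave copy mode.
 */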
static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
{
    struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
    int i;

    DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);

    for (i = 0; i < shinfo->nr_frags; i++)
        __skb_frag_unref(&shinfo->frags[i], false);
    shinfo->nr_frags = 0;
    strp->copy_mode = 0;
}

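/* ->read_sock() callback used in copy mode. Copies data from the TCP queue
 * into the anchor's page frags, parsing the record header first to learn
 * the full record length, and marks the message ready once a complete
 * record has been assembled. Returns the number of bytes consumed.
 */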
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
               unsigned int offset, size_t in_len)
{
    struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
    struct sk_buff *skb;
    skb_frag_t *frag;
    size_t len, chunk;
    int sz;

    if (strp->msg_ready)
        return 0;

    skb = strp->anchor;
    frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];

    len = in_len;
    /* First make sure we got the header */
    if (!strp->stm.full_len) {
        /* Assume one page is more than enough for headers */
        chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
        WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
                       skb_frag_address(frag) +
                       skb_frag_size(frag),
                       chunk));

        sz = tls_rx_msg_size(strp, strp->anchor);
        if (sz < 0) {
            desc->error = sz;
            return 0;
        }

        /* We may have over-read; sz == 0 is a guaranteed under-read */
        if (sz > 0)
            chunk = min_t(size_t, chunk, sz - skb->len);

        skb->len += chunk;
        skb->data_len += chunk;
        skb_frag_size_add(frag, chunk);
        frag++;
        len -= chunk;
        offset += chunk;

        strp->stm.full_len = sz;
        if (!strp->stm.full_len)
            goto read_done;
    }

    /* Load up more data */
    while (len && strp->stm.full_len > skb->len) {
        chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
        chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
        WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
                       skb_frag_address(frag) +
                       skb_frag_size(frag),
                       chunk));

        skb->len += chunk;
        skb->data_len += chunk;
        skb_frag_size_add(frag, chunk);
        frag++;
        len -= chunk;
        offset += chunk;
    }

    if (strp->stm.full_len == skb->len) {
        desc->count = 0;

        strp->msg_ready = 1;
        tls_rx_msg_ready(strp);
    }

read_done:
    return in_len - len;
}

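/* Pull data out of the TCP queue via ->read_sock(), feeding it to
 * tls_strp_copyin(). The caller must hold the socket lock.
 */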
static int tls_strp_read_copyin(struct tls_strparser *strp)
{
    struct socket *sock = strp->sk->sk_socket;
    read_descriptor_t desc;

    desc.arg.data = strp;
    desc.error = 0;
    desc.count = 1; /* give more than one skb per call */

    /* sk should be locked here, so okay to do read_sock */
    sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);

    return desc.error;
}

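/* Less data is queued than a full record (or the record length is not
 * known yet). Normally just wait for more data, but under receive buffer
 * pressure switch to copy mode and read out what is already there.
 */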
static int tls_strp_read_short(struct tls_strparser *strp)
{
    struct skb_shared_info *shinfo;
    struct page *page;
    int need_spc, len;

    /* If the rbuf is small or the rcv window has collapsed to 0 we need
     * to read the data out. Otherwise the connection will stall.
     * Without pressure, a threshold of INT_MAX will never be ready.
     */
    if (likely(!tcp_epollin_ready(strp->sk, INT_MAX)))
        return 0;

    shinfo = skb_shinfo(strp->anchor);
    shinfo->frag_list = NULL;

    /* If we don't know the length go max plus page for cipher overhead */
    need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;

    for (len = need_spc; len > 0; len -= PAGE_SIZE) {
        page = alloc_page(strp->sk->sk_allocation);
        if (!page) {
            tls_strp_flush_anchor_copy(strp);
            return -ENOMEM;
        }

        skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
                   page, 0, 0);
    }

    strp->copy_mode = 1;
    strp->stm.offset = 0;

    strp->anchor->len = 0;
    strp->anchor->data_len = 0;
    strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);

    tls_strp_read_copyin(strp);

    return 0;
}

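/* Point the anchor's frag_list at the skbs sitting in the TCP receive
 * queue so that @len bytes of record data can be parsed in place.
 */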
static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
{
    struct tcp_sock *tp = tcp_sk(strp->sk);
    struct sk_buff *first;
    u32 offset;

    first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
    if (WARN_ON_ONCE(!first))
        return;

    /* Bestow the state onto the anchor */
    strp->anchor->len = offset + len;
    strp->anchor->data_len = offset + len;
    strp->anchor->truesize = offset + len;

    skb_shinfo(strp->anchor)->frag_list = first;

    skb_copy_header(strp->anchor, first);
    strp->anchor->destructor = NULL;

    strp->stm.offset = offset;
}

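/* Make the anchor describe the current record: fill in the strp_msg and
 * tls_msg state, re-attaching the TCP receive queue first if a refresh
 * is requested and we are not in copy mode.
 */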
void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
    struct strp_msg *rxm;
    struct tls_msg *tlm;

    DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
    DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);

    if (!strp->copy_mode && force_refresh) {
        if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
            return;

        tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
    }

    rxm = strp_msg(strp->anchor);
    rxm->full_len   = strp->stm.full_len;
    rxm->offset = strp->stm.offset;
    tlm = tls_msg(strp->anchor);
    tlm->control    = strp->mark;
}

/* Called with lock held on lower socket */
static int tls_strp_read_sock(struct tls_strparser *strp)
{
    int sz, inq;

    inq = tcp_inq(strp->sk);
    if (inq < 1)
        return 0;

    if (unlikely(strp->copy_mode))
        return tls_strp_read_copyin(strp);

    if (inq < strp->stm.full_len)
        return tls_strp_read_short(strp);

    if (!strp->stm.full_len) {
        tls_strp_load_anchor_with_queue(strp, inq);

        sz = tls_rx_msg_size(strp, strp->anchor);
        if (sz < 0) {
            tls_strp_abort_strp(strp, sz);
            return sz;
        }

        strp->stm.full_len = sz;

        if (!strp->stm.full_len || inq < strp->stm.full_len)
            return tls_strp_read_short(strp);
    }

    strp->msg_ready = 1;
    tls_rx_msg_ready(strp);

    return 0;
}

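/* Try to parse a record from the data currently queued on the socket.
 * On -ENOMEM, retry later from the workqueue.
 */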
void tls_strp_check_rcv(struct tls_strparser *strp)
{
    if (unlikely(strp->stopped) || strp->msg_ready)
        return;

    if (tls_strp_read_sock(strp) == -ENOMEM)
        queue_work(tls_strp_wq, &strp->work);
}

/* Lower sock lock held */
void tls_strp_data_ready(struct tls_strparser *strp)
{
    /* This check is needed to synchronize with tls_strp_work().
     * tls_strp_work() acquires a process lock (lock_sock) whereas
     * the lock held here is bh_lock_sock. The two locks can be
     * held by different threads at the same time, but bh_lock_sock
     * allows a thread in BH context to safely check if the process
     * lock is held. In this case, if the lock is held, queue work.
     */
    if (sock_owned_by_user_nocheck(strp->sk)) {
        queue_work(tls_strp_wq, &strp->work);
        return;
    }

    tls_strp_check_rcv(strp);
}

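/* Workqueue handler: retry receive processing with the socket lock held. */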
static void tls_strp_work(struct work_struct *w)
{
    struct tls_strparser *strp =
        container_of(w, struct tls_strparser, work);

    lock_sock(strp->sk);
    tls_strp_check_rcv(strp);
    release_sock(strp->sk);
}

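/* The consumer is done with the current record. Return the data to TCP
 * (or drop our copy), clear the parser state and look for the next record.
 */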
void tls_strp_msg_done(struct tls_strparser *strp)
{
    WARN_ON(!strp->stm.full_len);

    if (likely(!strp->copy_mode))
        tcp_read_done(strp->sk, strp->stm.full_len);
    else
        tls_strp_flush_anchor_copy(strp);

    strp->msg_ready = 0;
    memset(&strp->stm, 0, sizeof(strp->stm));

    tls_strp_check_rcv(strp);
}

void tls_strp_stop(struct tls_strparser *strp)
{
    strp->stopped = 1;
}

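/* Initialize the parser state for @sk and allocate the anchor skb. */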
int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
{
    memset(strp, 0, sizeof(*strp));

    strp->sk = sk;

    strp->anchor = alloc_skb(0, GFP_KERNEL);
    if (!strp->anchor)
        return -ENOMEM;

    INIT_WORK(&strp->work, tls_strp_work);

    return 0;
}

/* strp must already be stopped so that the parser will no longer run.
 * Note that tls_strp_done is not called with the lower socket held.
 */
void tls_strp_done(struct tls_strparser *strp)
{
    WARN_ON(!strp->stopped);

    cancel_work_sync(&strp->work);
    tls_strp_anchor_free(strp);
}

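/* Create the workqueue used for deferred parsing when it cannot be done
 * directly in the data-ready path.
 */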
int __init tls_strp_dev_init(void)
{
    tls_strp_wq = create_workqueue("tls-strp");
    if (unlikely(!tls_strp_wq))
        return -ENOMEM;

    return 0;
}

void tls_strp_dev_exit(void)
{
    destroy_workqueue(tls_strp_wq);
}