0001
0002
0003
0004
0005
0006
0007
0008 #include <linux/bpf.h>
0009 #include <linux/errno.h>
0010 #include <linux/errqueue.h>
0011 #include <linux/file.h>
0012 #include <linux/in.h>
0013 #include <linux/kernel.h>
0014 #include <linux/export.h>
0015 #include <linux/init.h>
0016 #include <linux/net.h>
0017 #include <linux/netdevice.h>
0018 #include <linux/poll.h>
0019 #include <linux/rculist.h>
0020 #include <linux/skbuff.h>
0021 #include <linux/socket.h>
0022 #include <linux/uaccess.h>
0023 #include <linux/workqueue.h>
0024 #include <net/strparser.h>
0025 #include <net/netns/generic.h>
0026 #include <net/sock.h>
0027
0028 static struct workqueue_struct *strp_wq;
0029
0030 static inline struct _strp_msg *_strp_msg(struct sk_buff *skb)
0031 {
0032 return (struct _strp_msg *)((void *)skb->cb +
0033 offsetof(struct sk_skb_cb, strp));
0034 }
0035
0036
0037 static void strp_abort_strp(struct strparser *strp, int err)
0038 {
0039
0040
0041 cancel_delayed_work(&strp->msg_timer_work);
0042
0043 if (strp->stopped)
0044 return;
0045
0046 strp->stopped = 1;
0047
0048 if (strp->sk) {
0049 struct sock *sk = strp->sk;
0050
0051
0052 sk->sk_err = -err;
0053 sk_error_report(sk);
0054 }
0055 }
0056
0057 static void strp_start_timer(struct strparser *strp, long timeo)
0058 {
0059 if (timeo && timeo != LONG_MAX)
0060 mod_delayed_work(strp_wq, &strp->msg_timer_work, timeo);
0061 }
0062
0063
0064 static void strp_parser_err(struct strparser *strp, int err,
0065 read_descriptor_t *desc)
0066 {
0067 desc->error = err;
0068 kfree_skb(strp->skb_head);
0069 strp->skb_head = NULL;
0070 strp->cb.abort_parser(strp, err);
0071 }
0072
0073 static inline int strp_peek_len(struct strparser *strp)
0074 {
0075 if (strp->sk) {
0076 struct socket *sock = strp->sk->sk_socket;
0077
0078 return sock->ops->peek_len(sock);
0079 }
0080
0081
0082
0083
0084
0085 return INT_MAX;
0086 }
0087
0088
0089 static int __strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
0090 unsigned int orig_offset, size_t orig_len,
0091 size_t max_msg_size, long timeo)
0092 {
0093 struct strparser *strp = (struct strparser *)desc->arg.data;
0094 struct _strp_msg *stm;
0095 struct sk_buff *head, *skb;
0096 size_t eaten = 0, cand_len;
0097 ssize_t extra;
0098 int err;
0099 bool cloned_orig = false;
0100
0101 if (strp->paused)
0102 return 0;
0103
0104 head = strp->skb_head;
0105 if (head) {
0106
0107 if (unlikely(orig_offset)) {
0108
0109
0110
0111
0112
0113 orig_skb = skb_clone(orig_skb, GFP_ATOMIC);
0114 if (!orig_skb) {
0115 STRP_STATS_INCR(strp->stats.mem_fail);
0116 desc->error = -ENOMEM;
0117 return 0;
0118 }
0119 if (!pskb_pull(orig_skb, orig_offset)) {
0120 STRP_STATS_INCR(strp->stats.mem_fail);
0121 kfree_skb(orig_skb);
0122 desc->error = -ENOMEM;
0123 return 0;
0124 }
0125 cloned_orig = true;
0126 orig_offset = 0;
0127 }
0128
0129 if (!strp->skb_nextp) {
0130
0131
0132
0133 err = skb_unclone(head, GFP_ATOMIC);
0134 if (err) {
0135 STRP_STATS_INCR(strp->stats.mem_fail);
0136 desc->error = err;
0137 return 0;
0138 }
0139
0140 if (unlikely(skb_shinfo(head)->frag_list)) {
0141
0142
0143
0144
0145
0146
0147 if (WARN_ON(head->next)) {
0148 desc->error = -EINVAL;
0149 return 0;
0150 }
0151
0152 skb = alloc_skb_for_msg(head);
0153 if (!skb) {
0154 STRP_STATS_INCR(strp->stats.mem_fail);
0155 desc->error = -ENOMEM;
0156 return 0;
0157 }
0158
0159 strp->skb_nextp = &head->next;
0160 strp->skb_head = skb;
0161 head = skb;
0162 } else {
0163 strp->skb_nextp =
0164 &skb_shinfo(head)->frag_list;
0165 }
0166 }
0167 }
0168
0169 while (eaten < orig_len) {
0170
0171 skb = skb_clone(orig_skb, GFP_ATOMIC);
0172 if (!skb) {
0173 STRP_STATS_INCR(strp->stats.mem_fail);
0174 desc->error = -ENOMEM;
0175 break;
0176 }
0177
0178 cand_len = orig_len - eaten;
0179
0180 head = strp->skb_head;
0181 if (!head) {
0182 head = skb;
0183 strp->skb_head = head;
0184
0185 strp->skb_nextp = NULL;
0186 stm = _strp_msg(head);
0187 memset(stm, 0, sizeof(*stm));
0188 stm->strp.offset = orig_offset + eaten;
0189 } else {
0190
0191
0192
0193 if (skb_has_frag_list(skb)) {
0194 err = skb_unclone(skb, GFP_ATOMIC);
0195 if (err) {
0196 STRP_STATS_INCR(strp->stats.mem_fail);
0197 desc->error = err;
0198 break;
0199 }
0200 }
0201
0202 stm = _strp_msg(head);
0203 *strp->skb_nextp = skb;
0204 strp->skb_nextp = &skb->next;
0205 head->data_len += skb->len;
0206 head->len += skb->len;
0207 head->truesize += skb->truesize;
0208 }
0209
0210 if (!stm->strp.full_len) {
0211 ssize_t len;
0212
0213 len = (*strp->cb.parse_msg)(strp, head);
0214
0215 if (!len) {
0216
0217 if (!stm->accum_len) {
0218
0219 strp_start_timer(strp, timeo);
0220 }
0221 stm->accum_len += cand_len;
0222 eaten += cand_len;
0223 STRP_STATS_INCR(strp->stats.need_more_hdr);
0224 WARN_ON(eaten != orig_len);
0225 break;
0226 } else if (len < 0) {
0227 if (len == -ESTRPIPE && stm->accum_len) {
0228 len = -ENODATA;
0229 strp->unrecov_intr = 1;
0230 } else {
0231 strp->interrupted = 1;
0232 }
0233 strp_parser_err(strp, len, desc);
0234 break;
0235 } else if (len > max_msg_size) {
0236
0237 STRP_STATS_INCR(strp->stats.msg_too_big);
0238 strp_parser_err(strp, -EMSGSIZE, desc);
0239 break;
0240 } else if (len <= (ssize_t)head->len -
0241 skb->len - stm->strp.offset) {
0242
0243
0244
0245 STRP_STATS_INCR(strp->stats.bad_hdr_len);
0246 strp_parser_err(strp, -EPROTO, desc);
0247 break;
0248 }
0249
0250 stm->strp.full_len = len;
0251 }
0252
0253 extra = (ssize_t)(stm->accum_len + cand_len) -
0254 stm->strp.full_len;
0255
0256 if (extra < 0) {
0257
0258 if (stm->strp.full_len - stm->accum_len >
0259 strp_peek_len(strp)) {
0260
0261
0262
0263
0264
0265
0266
0267 if (!stm->accum_len) {
0268
0269 strp_start_timer(strp, timeo);
0270 }
0271
0272 stm->accum_len += cand_len;
0273 eaten += cand_len;
0274 strp->need_bytes = stm->strp.full_len -
0275 stm->accum_len;
0276 STRP_STATS_ADD(strp->stats.bytes, cand_len);
0277 desc->count = 0;
0278 break;
0279 }
0280 stm->accum_len += cand_len;
0281 eaten += cand_len;
0282 WARN_ON(eaten != orig_len);
0283 break;
0284 }
0285
0286
0287
0288
0289
0290 WARN_ON(extra > cand_len);
0291
0292 eaten += (cand_len - extra);
0293
0294
0295 cancel_delayed_work(&strp->msg_timer_work);
0296 strp->skb_head = NULL;
0297 strp->need_bytes = 0;
0298 STRP_STATS_INCR(strp->stats.msgs);
0299
0300
0301 strp->cb.rcv_msg(strp, head);
0302
0303 if (unlikely(strp->paused)) {
0304
0305 break;
0306 }
0307 }
0308
0309 if (cloned_orig)
0310 kfree_skb(orig_skb);
0311
0312 STRP_STATS_ADD(strp->stats.bytes, eaten);
0313
0314 return eaten;
0315 }
0316
0317 int strp_process(struct strparser *strp, struct sk_buff *orig_skb,
0318 unsigned int orig_offset, size_t orig_len,
0319 size_t max_msg_size, long timeo)
0320 {
0321 read_descriptor_t desc;
0322
0323 desc.arg.data = strp;
0324
0325 return __strp_recv(&desc, orig_skb, orig_offset, orig_len,
0326 max_msg_size, timeo);
0327 }
0328 EXPORT_SYMBOL_GPL(strp_process);
0329
0330 static int strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
0331 unsigned int orig_offset, size_t orig_len)
0332 {
0333 struct strparser *strp = (struct strparser *)desc->arg.data;
0334
0335 return __strp_recv(desc, orig_skb, orig_offset, orig_len,
0336 strp->sk->sk_rcvbuf, strp->sk->sk_rcvtimeo);
0337 }
0338
0339 static int default_read_sock_done(struct strparser *strp, int err)
0340 {
0341 return err;
0342 }
0343
0344
0345 static int strp_read_sock(struct strparser *strp)
0346 {
0347 struct socket *sock = strp->sk->sk_socket;
0348 read_descriptor_t desc;
0349
0350 if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
0351 return -EBUSY;
0352
0353 desc.arg.data = strp;
0354 desc.error = 0;
0355 desc.count = 1;
0356
0357
0358 sock->ops->read_sock(strp->sk, &desc, strp_recv);
0359
0360 desc.error = strp->cb.read_sock_done(strp, desc.error);
0361
0362 return desc.error;
0363 }
0364
0365
0366 void strp_data_ready(struct strparser *strp)
0367 {
0368 if (unlikely(strp->stopped) || strp->paused)
0369 return;
0370
0371
0372
0373
0374
0375
0376
0377
0378 if (sock_owned_by_user_nocheck(strp->sk)) {
0379 queue_work(strp_wq, &strp->work);
0380 return;
0381 }
0382
0383 if (strp->need_bytes) {
0384 if (strp_peek_len(strp) < strp->need_bytes)
0385 return;
0386 }
0387
0388 if (strp_read_sock(strp) == -ENOMEM)
0389 queue_work(strp_wq, &strp->work);
0390 }
0391 EXPORT_SYMBOL_GPL(strp_data_ready);
0392
0393 static void do_strp_work(struct strparser *strp)
0394 {
0395
0396
0397
0398 strp->cb.lock(strp);
0399
0400 if (unlikely(strp->stopped))
0401 goto out;
0402
0403 if (strp->paused)
0404 goto out;
0405
0406 if (strp_read_sock(strp) == -ENOMEM)
0407 queue_work(strp_wq, &strp->work);
0408
0409 out:
0410 strp->cb.unlock(strp);
0411 }
0412
0413 static void strp_work(struct work_struct *w)
0414 {
0415 do_strp_work(container_of(w, struct strparser, work));
0416 }
0417
0418 static void strp_msg_timeout(struct work_struct *w)
0419 {
0420 struct strparser *strp = container_of(w, struct strparser,
0421 msg_timer_work.work);
0422
0423
0424 STRP_STATS_INCR(strp->stats.msg_timeouts);
0425 strp->cb.lock(strp);
0426 strp->cb.abort_parser(strp, -ETIMEDOUT);
0427 strp->cb.unlock(strp);
0428 }
0429
0430 static void strp_sock_lock(struct strparser *strp)
0431 {
0432 lock_sock(strp->sk);
0433 }
0434
0435 static void strp_sock_unlock(struct strparser *strp)
0436 {
0437 release_sock(strp->sk);
0438 }
0439
0440 int strp_init(struct strparser *strp, struct sock *sk,
0441 const struct strp_callbacks *cb)
0442 {
0443
0444 if (!cb || !cb->rcv_msg || !cb->parse_msg)
0445 return -EINVAL;
0446
0447
0448
0449
0450
0451
0452
0453
0454
0455
0456
0457
0458 if (!sk) {
0459 if (!cb->lock || !cb->unlock)
0460 return -EINVAL;
0461 }
0462
0463 memset(strp, 0, sizeof(*strp));
0464
0465 strp->sk = sk;
0466
0467 strp->cb.lock = cb->lock ? : strp_sock_lock;
0468 strp->cb.unlock = cb->unlock ? : strp_sock_unlock;
0469 strp->cb.rcv_msg = cb->rcv_msg;
0470 strp->cb.parse_msg = cb->parse_msg;
0471 strp->cb.read_sock_done = cb->read_sock_done ? : default_read_sock_done;
0472 strp->cb.abort_parser = cb->abort_parser ? : strp_abort_strp;
0473
0474 INIT_DELAYED_WORK(&strp->msg_timer_work, strp_msg_timeout);
0475 INIT_WORK(&strp->work, strp_work);
0476
0477 return 0;
0478 }
0479 EXPORT_SYMBOL_GPL(strp_init);
0480
0481
0482 void __strp_unpause(struct strparser *strp)
0483 {
0484 strp->paused = 0;
0485
0486 if (strp->need_bytes) {
0487 if (strp_peek_len(strp) < strp->need_bytes)
0488 return;
0489 }
0490 strp_read_sock(strp);
0491 }
0492 EXPORT_SYMBOL_GPL(__strp_unpause);
0493
0494 void strp_unpause(struct strparser *strp)
0495 {
0496 strp->paused = 0;
0497
0498
0499 smp_mb();
0500
0501 queue_work(strp_wq, &strp->work);
0502 }
0503 EXPORT_SYMBOL_GPL(strp_unpause);
0504
0505
0506
0507
0508 void strp_done(struct strparser *strp)
0509 {
0510 WARN_ON(!strp->stopped);
0511
0512 cancel_delayed_work_sync(&strp->msg_timer_work);
0513 cancel_work_sync(&strp->work);
0514
0515 if (strp->skb_head) {
0516 kfree_skb(strp->skb_head);
0517 strp->skb_head = NULL;
0518 }
0519 }
0520 EXPORT_SYMBOL_GPL(strp_done);
0521
0522 void strp_stop(struct strparser *strp)
0523 {
0524 strp->stopped = 1;
0525 }
0526 EXPORT_SYMBOL_GPL(strp_stop);
0527
0528 void strp_check_rcv(struct strparser *strp)
0529 {
0530 queue_work(strp_wq, &strp->work);
0531 }
0532 EXPORT_SYMBOL_GPL(strp_check_rcv);
0533
0534 static int __init strp_dev_init(void)
0535 {
0536 BUILD_BUG_ON(sizeof(struct sk_skb_cb) >
0537 sizeof_field(struct sk_buff, cb));
0538
0539 strp_wq = create_singlethread_workqueue("kstrp");
0540 if (unlikely(!strp_wq))
0541 return -ENOMEM;
0542
0543 return 0;
0544 }
0545 device_initcall(strp_dev_init);