#include <linux/filter.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/types.h>
#include <linux/bpf.h>
#include <net/lwtunnel.h>
#include <net/gre.h>
#include <net/ip6_route.h>
#include <net/ipv6_stubs.h>

struct bpf_lwt_prog {
	struct bpf_prog *prog;
	char *name;
};

struct bpf_lwt {
	struct bpf_lwt_prog in;
	struct bpf_lwt_prog out;
	struct bpf_lwt_prog xmit;
	int family;
};

#define MAX_PROG_NAME 256

static inline struct bpf_lwt *bpf_lwt_lwtunnel(struct lwtunnel_state *lwt)
{
	return (struct bpf_lwt *)lwt->data;
}

#define NO_REDIRECT false
#define CAN_REDIRECT true

static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt,
		       struct dst_entry *dst, bool can_redirect)
{
	int ret;

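	/* Disabling migration and BHs keeps the BPF program and a possible
	 * skb_do_redirect() on the same CPU, protecting the per-CPU redirect
	 * state shared between them.
	 */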
	migrate_disable();
	local_bh_disable();
	bpf_compute_data_pointers(skb);
	ret = bpf_prog_run_save_cb(lwt->prog, skb);

	switch (ret) {
	case BPF_OK:
	case BPF_LWT_REROUTE:
		break;

	case BPF_REDIRECT:
		if (unlikely(!can_redirect)) {
			pr_warn_once("Illegal redirect return code in prog %s\n",
				     lwt->name ? : "<unknown>");
			ret = BPF_OK;
		} else {
			skb_reset_mac_header(skb);
			ret = skb_do_redirect(skb);
			if (ret == 0)
				ret = BPF_REDIRECT;
		}
		break;

	case BPF_DROP:
		kfree_skb(skb);
		ret = -EPERM;
		break;

	default:
		pr_warn_once("bpf-lwt: Illegal return value %u, expect packet loss\n", ret);
		kfree_skb(skb);
		ret = -EINVAL;
		break;
	}

	local_bh_enable();
	migrate_enable();

	return ret;
}

static int bpf_lwt_input_reroute(struct sk_buff *skb)
{
	int err = -EINVAL;

	if (skb->protocol == htons(ETH_P_IP)) {
		struct net_device *dev = skb_dst(skb)->dev;
		struct iphdr *iph = ip_hdr(skb);

		dev_hold(dev);
		skb_dst_drop(skb);
		err = ip_route_input_noref(skb, iph->daddr, iph->saddr,
					   iph->tos, dev);
		dev_put(dev);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		skb_dst_drop(skb);
		err = ipv6_stub->ipv6_route_input(skb);
	} else {
		err = -EAFNOSUPPORT;
	}

	if (err)
		goto err;
	return dst_input(skb);

err:
	kfree_skb(skb);
	return err;
}

static int bpf_input(struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;
	int ret;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->in.prog) {
		ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT);
		if (ret < 0)
			return ret;
		if (ret == BPF_LWT_REROUTE)
			return bpf_lwt_input_reroute(skb);
	}

	if (unlikely(!dst->lwtstate->orig_input)) {
		kfree_skb(skb);
		return -EINVAL;
	}

	return dst->lwtstate->orig_input(skb);
}

static int bpf_output(struct net *net, struct sock *sk, struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;
	int ret;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->out.prog) {
		ret = run_lwt_bpf(skb, &bpf->out, dst, NO_REDIRECT);
		if (ret < 0)
			return ret;
	}

	if (unlikely(!dst->lwtstate->orig_output)) {
		pr_warn_once("orig_output not set on dst for prog %s\n",
			     bpf->out.name);
		kfree_skb(skb);
		return -EINVAL;
	}

	return dst->lwtstate->orig_output(net, sk, skb);
}

static int xmit_check_hhlen(struct sk_buff *skb, int hh_len)
{
	if (skb_headroom(skb) < hh_len) {
		int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb));

		if (pskb_expand_head(skb, nhead, 0, GFP_ATOMIC))
			return -ENOMEM;
	}

	return 0;
}

static int bpf_lwt_xmit_reroute(struct sk_buff *skb)
{
	struct net_device *l3mdev = l3mdev_master_dev_rcu(skb_dst(skb)->dev);
	int oif = l3mdev ? l3mdev->ifindex : 0;
	struct dst_entry *dst = NULL;
	int err = -EAFNOSUPPORT;
	struct sock *sk;
	struct net *net;
	bool ipv4;

	if (skb->protocol == htons(ETH_P_IP))
		ipv4 = true;
	else if (skb->protocol == htons(ETH_P_IPV6))
		ipv4 = false;
	else
		goto err;

	sk = sk_to_full_sk(skb->sk);
	if (sk) {
		if (sk->sk_bound_dev_if)
			oif = sk->sk_bound_dev_if;
		net = sock_net(sk);
	} else {
		net = dev_net(skb_dst(skb)->dev);
	}

	if (ipv4) {
		struct iphdr *iph = ip_hdr(skb);
		struct flowi4 fl4 = {};
		struct rtable *rt;

		fl4.flowi4_oif = oif;
		fl4.flowi4_mark = skb->mark;
		fl4.flowi4_uid = sock_net_uid(net, sk);
		fl4.flowi4_tos = RT_TOS(iph->tos);
		fl4.flowi4_flags = FLOWI_FLAG_ANYSRC;
		fl4.flowi4_proto = iph->protocol;
		fl4.daddr = iph->daddr;
		fl4.saddr = iph->saddr;

		rt = ip_route_output_key(net, &fl4);
		if (IS_ERR(rt)) {
			err = PTR_ERR(rt);
			goto err;
		}
		dst = &rt->dst;
	} else {
		struct ipv6hdr *iph6 = ipv6_hdr(skb);
		struct flowi6 fl6 = {};

		fl6.flowi6_oif = oif;
		fl6.flowi6_mark = skb->mark;
		fl6.flowi6_uid = sock_net_uid(net, sk);
		fl6.flowlabel = ip6_flowinfo(iph6);
		fl6.flowi6_proto = iph6->nexthdr;
		fl6.daddr = iph6->daddr;
		fl6.saddr = iph6->saddr;

		dst = ipv6_stub->ipv6_dst_lookup_flow(net, skb->sk, &fl6, NULL);
		if (IS_ERR(dst)) {
			err = PTR_ERR(dst);
			goto err;
		}
	}
	if (unlikely(dst->error)) {
		err = dst->error;
		dst_release(dst);
		goto err;
	}

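	/* Headroom may already have been reserved by bpf_lwt_push_ip_encap(),
	 * but that was done for the previous dst. Make sure there is enough
	 * room for the new device's link-layer header; skb_cow_head() is a
	 * no-op if the headroom is already sufficient.
	 */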
	err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev));
	if (unlikely(err))
		goto err;

	skb_dst_drop(skb);
	skb_dst_set(skb, dst);

	err = dst_output(dev_net(skb_dst(skb)->dev), skb->sk, skb);
	if (unlikely(err))
		return err;

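	/* The skb has been handed to dst_output(); report to the LWT core
	 * that transmission is complete.
	 */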
	return LWTUNNEL_XMIT_DONE;

err:
	kfree_skb(skb);
	return err;
}

static int bpf_xmit(struct sk_buff *skb)
{
	struct dst_entry *dst = skb_dst(skb);
	struct bpf_lwt *bpf;

	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
	if (bpf->xmit.prog) {
		int hh_len = dst->dev->hard_header_len;
		__be16 proto = skb->protocol;
		int ret;

		ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT);
		switch (ret) {
		case BPF_OK:
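			/* If the program rewrote the header such that the
			 * protocol changed, it must return BPF_LWT_REROUTE
			 * instead of BPF_OK.
			 */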
			if (skb->protocol != proto) {
				kfree_skb(skb);
				return -EINVAL;
			}

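			/* If the header grew, the remaining headroom may be
			 * too small for the L2 header; expand it as needed.
			 */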
			ret = xmit_check_hhlen(skb, hh_len);
			if (unlikely(ret))
				return ret;

			return LWTUNNEL_XMIT_CONTINUE;
		case BPF_REDIRECT:
			return LWTUNNEL_XMIT_DONE;
		case BPF_LWT_REROUTE:
			return bpf_lwt_xmit_reroute(skb);
		default:
			return ret;
		}
	}

	return LWTUNNEL_XMIT_CONTINUE;
}

static void bpf_lwt_prog_destroy(struct bpf_lwt_prog *prog)
{
	if (prog->prog)
		bpf_prog_put(prog->prog);

	kfree(prog->name);
}

static void bpf_destroy_state(struct lwtunnel_state *lwt)
{
	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

	bpf_lwt_prog_destroy(&bpf->in);
	bpf_lwt_prog_destroy(&bpf->out);
	bpf_lwt_prog_destroy(&bpf->xmit);
}

static const struct nla_policy bpf_prog_policy[LWT_BPF_PROG_MAX + 1] = {
	[LWT_BPF_PROG_FD]   = { .type = NLA_U32, },
	[LWT_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
				.len = MAX_PROG_NAME },
};

static int bpf_parse_prog(struct nlattr *attr, struct bpf_lwt_prog *prog,
			  enum bpf_prog_type type)
{
	struct nlattr *tb[LWT_BPF_PROG_MAX + 1];
	struct bpf_prog *p;
	int ret;
	u32 fd;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_PROG_MAX, attr,
					  bpf_prog_policy, NULL);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_PROG_FD] || !tb[LWT_BPF_PROG_NAME])
		return -EINVAL;

	prog->name = nla_memdup(tb[LWT_BPF_PROG_NAME], GFP_ATOMIC);
	if (!prog->name)
		return -ENOMEM;

	fd = nla_get_u32(tb[LWT_BPF_PROG_FD]);
	p = bpf_prog_get_type(fd, type);
	if (IS_ERR(p))
		return PTR_ERR(p);

	prog->prog = p;

	return 0;
}

static const struct nla_policy bpf_nl_policy[LWT_BPF_MAX + 1] = {
	[LWT_BPF_IN]		= { .type = NLA_NESTED, },
	[LWT_BPF_OUT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT_HEADROOM]	= { .type = NLA_U32 },
};

static int bpf_build_state(struct net *net, struct nlattr *nla,
			   unsigned int family, const void *cfg,
			   struct lwtunnel_state **ts,
			   struct netlink_ext_ack *extack)
{
	struct nlattr *tb[LWT_BPF_MAX + 1];
	struct lwtunnel_state *newts;
	struct bpf_lwt *bpf;
	int ret;

	if (family != AF_INET && family != AF_INET6)
		return -EAFNOSUPPORT;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_MAX, nla, bpf_nl_policy,
					  extack);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_IN] && !tb[LWT_BPF_OUT] && !tb[LWT_BPF_XMIT])
		return -EINVAL;

	newts = lwtunnel_state_alloc(sizeof(*bpf));
	if (!newts)
		return -ENOMEM;

	newts->type = LWTUNNEL_ENCAP_BPF;
	bpf = bpf_lwt_lwtunnel(newts);

	if (tb[LWT_BPF_IN]) {
		newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_IN], &bpf->in,
				     BPF_PROG_TYPE_LWT_IN);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_OUT]) {
		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_OUT], &bpf->out,
				     BPF_PROG_TYPE_LWT_OUT);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_XMIT]) {
		newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_XMIT], &bpf->xmit,
				     BPF_PROG_TYPE_LWT_XMIT);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_XMIT_HEADROOM]) {
		u32 headroom = nla_get_u32(tb[LWT_BPF_XMIT_HEADROOM]);

		if (headroom > LWT_BPF_MAX_HEADROOM) {
			ret = -ERANGE;
			goto errout;
		}

		newts->headroom = headroom;
	}

	bpf->family = family;
	*ts = newts;

	return 0;

errout:
	bpf_destroy_state(newts);
	kfree(newts);
	return ret;
}

static int bpf_fill_lwt_prog(struct sk_buff *skb, int attr,
			     struct bpf_lwt_prog *prog)
{
	struct nlattr *nest;

	if (!prog->prog)
		return 0;

	nest = nla_nest_start_noflag(skb, attr);
	if (!nest)
		return -EMSGSIZE;

	if (prog->name &&
	    nla_put_string(skb, LWT_BPF_PROG_NAME, prog->name))
		return -EMSGSIZE;

	return nla_nest_end(skb, nest);
}

static int bpf_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwt)
{
	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);

	if (bpf_fill_lwt_prog(skb, LWT_BPF_IN, &bpf->in) < 0 ||
	    bpf_fill_lwt_prog(skb, LWT_BPF_OUT, &bpf->out) < 0 ||
	    bpf_fill_lwt_prog(skb, LWT_BPF_XMIT, &bpf->xmit) < 0)
		return -EMSGSIZE;

	return 0;
}

static int bpf_encap_nlsize(struct lwtunnel_state *lwtstate)
{
	int nest_len = nla_total_size(sizeof(struct nlattr)) +
		       nla_total_size(MAX_PROG_NAME) + /* LWT_BPF_PROG_NAME */
		       0;

	return nest_len + /* LWT_BPF_IN */
	       nest_len + /* LWT_BPF_OUT */
	       nest_len + /* LWT_BPF_XMIT */
	       0;
}

static int bpf_lwt_prog_cmp(struct bpf_lwt_prog *a, struct bpf_lwt_prog *b)
{
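	/* The LWT state is rebuilt for updates, so two states referring to
	 * the same program still hold distinct bpf_prog instances; compare
	 * by program name rather than by prog pointer.
	 */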
	if (!a->name && !b->name)
		return 0;

	if (!a->name || !b->name)
		return 1;

	return strcmp(a->name, b->name);
}

static int bpf_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
{
	struct bpf_lwt *a_bpf = bpf_lwt_lwtunnel(a);
	struct bpf_lwt *b_bpf = bpf_lwt_lwtunnel(b);

	return bpf_lwt_prog_cmp(&a_bpf->in, &b_bpf->in) ||
	       bpf_lwt_prog_cmp(&a_bpf->out, &b_bpf->out) ||
	       bpf_lwt_prog_cmp(&a_bpf->xmit, &b_bpf->xmit);
}

static const struct lwtunnel_encap_ops bpf_encap_ops = {
	.build_state	= bpf_build_state,
	.destroy_state	= bpf_destroy_state,
	.input		= bpf_input,
	.output		= bpf_output,
	.xmit		= bpf_xmit,
	.fill_encap	= bpf_fill_encap_info,
	.get_encap_size	= bpf_encap_nlsize,
	.cmp_encap	= bpf_encap_cmp,
	.owner		= THIS_MODULE,
};

static int handle_gso_type(struct sk_buff *skb, unsigned int gso_type,
			   int encap_len)
{
	struct skb_shared_info *shinfo = skb_shinfo(skb);

	gso_type |= SKB_GSO_DODGY;
	shinfo->gso_type |= gso_type;
	skb_decrease_gso_size(shinfo, encap_len);
	shinfo->gso_segs = 0;
	return 0;
}

static int handle_gso_encap(struct sk_buff *skb, bool ipv4, int encap_len)
{
	int next_hdr_offset;
	void *next_hdr;
	__u8 protocol;

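	/* Only TCP GSO packets are handled: other GSO types (e.g. SCTP or
	 * UDP_L4) would need more than the skb_decrease_gso_size()
	 * adjustment done in handle_gso_type().
	 */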
	if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6)))
		return -ENOTSUPP;

	if (ipv4) {
		protocol = ip_hdr(skb)->protocol;
		next_hdr_offset = sizeof(struct iphdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	} else {
		protocol = ipv6_hdr(skb)->nexthdr;
		next_hdr_offset = sizeof(struct ipv6hdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	}

	switch (protocol) {
	case IPPROTO_GRE:
		next_hdr_offset += sizeof(struct gre_base_hdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct gre_base_hdr *)next_hdr)->flags & GRE_CSUM)
			return handle_gso_type(skb, SKB_GSO_GRE_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_GRE, encap_len);

	case IPPROTO_UDP:
		next_hdr_offset += sizeof(struct udphdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct udphdr *)next_hdr)->check)
			return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL, encap_len);

	case IPPROTO_IP:
	case IPPROTO_IPV6:
		if (ipv4)
			return handle_gso_type(skb, SKB_GSO_IPXIP4, encap_len);
		else
			return handle_gso_type(skb, SKB_GSO_IPXIP6, encap_len);

	default:
		return -EPROTONOSUPPORT;
	}
}

int bpf_lwt_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, bool ingress)
{
	struct iphdr *iph;
	bool ipv4;
	int err;

	if (unlikely(len < sizeof(struct iphdr) || len > LWT_BPF_MAX_HEADROOM))
		return -EINVAL;

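	/* Validate the IP version and length of the header being pushed. */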
	iph = (struct iphdr *)hdr;
	if (iph->version == 4) {
		ipv4 = true;
		if (unlikely(len < iph->ihl * 4))
			return -EINVAL;
	} else if (iph->version == 6) {
		ipv4 = false;
		if (unlikely(len < sizeof(struct ipv6hdr)))
			return -EINVAL;
	} else {
		return -EINVAL;
	}

	if (ingress)
		err = skb_cow_head(skb, len + skb->mac_len);
	else
		err = skb_cow_head(skb,
				   len + LL_RESERVED_SPACE(skb_dst(skb)->dev));
	if (unlikely(err))
		return err;

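	/* Push the encap header and fix up the skb header pointers. */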
	skb_reset_inner_headers(skb);
	skb_reset_inner_mac_header(skb);
	skb_set_inner_protocol(skb, skb->protocol);
	skb->encapsulation = 1;
	skb_push(skb, len);
	if (ingress)
		skb_postpush_rcsum(skb, iph, len);
	skb_reset_network_header(skb);
	memcpy(skb_network_header(skb), hdr, len);
	bpf_compute_data_pointers(skb);
	skb_clear_hash(skb);

	if (ipv4) {
		skb->protocol = htons(ETH_P_IP);
		iph = ip_hdr(skb);

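		/* Fill in the IPv4 header checksum if the pushed header left
		 * it at zero.
		 */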
		if (!iph->check)
			iph->check = ip_fast_csum((unsigned char *)iph,
						  iph->ihl);
	} else {
		skb->protocol = htons(ETH_P_IPV6);
	}

	if (skb_is_gso(skb))
		return handle_gso_encap(skb, ipv4, len);

	return 0;
}

static int __init bpf_lwt_init(void)
{
	return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF);
}

subsys_initcall(bpf_lwt_init)