Back to home page

OSCL-LXR

 
 

    


0001 /* Copyright (c) 2017 Facebook
0002  *
0003  * This program is free software; you can redistribute it and/or
0004  * modify it under the terms of version 2 of the GNU General Public
0005  * License as published by the Free Software Foundation.
0006  */
0007 #include <stddef.h>
0008 #include <stdbool.h>
0009 #include <string.h>
0010 #include <linux/pkt_cls.h>
0011 #include <linux/bpf.h>
0012 #include <linux/in.h>
0013 #include <linux/if_ether.h>
0014 #include <linux/ip.h>
0015 #include <linux/ipv6.h>
0016 #include <linux/icmp.h>
0017 #include <linux/icmpv6.h>
0018 #include <linux/tcp.h>
0019 #include <linux/udp.h>
0020 #include <bpf/bpf_helpers.h>
0021 #include "test_iptunnel_common.h"
0022 #include <bpf/bpf_endian.h>
0023 
/* Copy of the kernel's jhash, kept local so that LLVM can compile it into a
 * valid sequence of BPF instructions.
 */

/* 32-bit word type used by the hash code below (same type as __u32). */
typedef unsigned int u32;
_Static_assert(sizeof(u32) == 4, "jhash requires a 32-bit u32");

/* Rotate @word left by @shift bits.
 * @shift must be below 32: shifting by the full width is undefined in C.
 */
static inline u32 rol32(u32 word, unsigned int shift)
{
	return (word << shift) | (word >> ((-shift) & 31));
}

/* Reversibly mix the three state words; applied after every full 12-byte
 * chunk. Wrapped in do { } while (0) so each use is a single statement and
 * stays correct inside an unbraced if/else.
 */
#define __jhash_mix(a, b, c)					\
do {								\
	(a) -= (c);  (a) ^= rol32((c), 4);  (c) += (b);		\
	(b) -= (a);  (b) ^= rol32((a), 6);  (a) += (c);		\
	(c) -= (b);  (c) ^= rol32((b), 8);  (b) += (a);		\
	(a) -= (c);  (a) ^= rol32((c), 16); (c) += (b);		\
	(b) -= (a);  (b) ^= rol32((a), 19); (a) += (c);		\
	(c) -= (b);  (c) ^= rol32((b), 4);  (b) += (a);		\
} while (0)

/* Final avalanche of the three state words; the hash result ends up in @c.
 * Same single-statement wrapping as __jhash_mix.
 */
#define __jhash_final(a, b, c)					\
do {								\
	(c) ^= (b); (c) -= rol32((b), 14);			\
	(a) ^= (c); (a) -= rol32((c), 11);			\
	(b) ^= (a); (b) -= rol32((a), 25);			\
	(c) ^= (b); (c) -= rol32((b), 16);			\
	(a) ^= (c); (a) -= rol32((c), 4);			\
	(b) ^= (a); (b) -= rol32((a), 14);			\
	(c) ^= (b); (c) -= rol32((b), 24);			\
} while (0)

/* Initial value added to the hash state before any input is mixed in. */
#define JHASH_INITVAL		0xdeadbeef

/* Load 4 bytes at @p as a native-endian u32.
 * Going through memcpy avoids the unaligned access and strict-aliasing
 * violation that *(u32 *)p has when @p points into an arbitrary byte buffer.
 * The size is constant, so the copy is expanded inline.
 */
static inline u32 jhash_load32(const unsigned char *p)
{
	u32 v;

	memcpy(&v, p, sizeof(v));
	return v;
}

/* Hash @length bytes at @key with seed @initval.
 *
 * Full 12-byte chunks (all but the last) are loaded as native-endian words
 * and mixed. The trailing 1..12 bytes are folded in byte by byte in
 * little-endian order, followed by the final avalanche. A zero-length key
 * skips all mixing and returns JHASH_INITVAL + @initval.
 */
static inline u32 jhash(const void *key, u32 length, u32 initval)
{
	u32 a, b, c;
	const unsigned char *k = key;

	a = b = c = JHASH_INITVAL + length + initval;

	while (length > 12) {
		a += jhash_load32(k);
		b += jhash_load32(k + 4);
		c += jhash_load32(k + 8);
		__jhash_mix(a, b, c);
		length -= 12;
		k += 12;
	}

	/* Last block: every case intentionally falls through to the next. */
	switch (length) {
	case 12: c += (u32)k[11] << 24;	/* fallthrough */
	case 11: c += (u32)k[10] << 16;	/* fallthrough */
	case 10: c += (u32)k[9] << 8;	/* fallthrough */
	case 9:  c += k[8];		/* fallthrough */
	case 8:  b += (u32)k[7] << 24;	/* fallthrough */
	case 7:  b += (u32)k[6] << 16;	/* fallthrough */
	case 6:  b += (u32)k[5] << 8;	/* fallthrough */
	case 5:  b += k[4];		/* fallthrough */
	case 4:  a += (u32)k[3] << 24;	/* fallthrough */
	case 3:  a += (u32)k[2] << 16;	/* fallthrough */
	case 2:  a += (u32)k[1] << 8;	/* fallthrough */
	case 1:  a += k[0];
		__jhash_final(a, b, c);
		break;
	case 0: /* Nothing left to add */
		break;
	}

	return c;
}

/* Add @initval to three pre-loaded words, run the final avalanche and
 * return the resulting c.
 */
static inline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
{
	a += initval;
	b += initval;
	c += initval;
	__jhash_final(a, b, c);
	return c;
}

/* Hash the two words @a and @b with seed @initval.
 * (2 << 2) is the input length in bytes (2 words * 4), matching the length
 * term jhash() adds to its initial state.
 */
static inline u32 jhash_2words(u32 a, u32 b, u32 initval)
{
	return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
}
0106 
/* Mask over iph->frag_off (network byte order) selecting the bits that mark
 * a fragment: MF (0x2000) plus the 13-bit fragment offset (0x1fff); DF is
 * not included. Expressed via bpf_htons so it is correct on both byte
 * orders (the former literal 65343 == 0xff3f only matched on little-endian).
 */
#define PCKT_FRAGMENTED bpf_htons(0x3fff)
/* IPv4 header length without options (only ihl == 5 is accepted). */
#define IPV4_HDR_LEN_NO_OPT 20
/* Bytes skipped after the outer L3 header to reach the quoted L4 header of
 * an ICMP error: 8-byte ICMP header + 20-byte quoted IPv4 header, and
 * 8-byte ICMPv6 header + 40-byte quoted IPv6 header respectively.
 */
#define IPV4_PLUS_ICMP_HDR 28
#define IPV6_PLUS_ICMP_HDR 48
/* ch_rings slots reserved per VIP. */
#define RING_SIZE 2
#define MAX_VIPS 12
#define MAX_REALS 5
#define CTL_MAP_SIZE 16
#define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
/* real_definition.flags: the real is reached over IPv6. */
#define F_IPV6 (1 << 0)
/* vip_meta.flags: do not include the source port in the hash. */
#define F_HASH_NO_SRC_PORT (1 << 0)
/* packet_description.flags: tuple came from an ICMP error's quoted header. */
#define F_ICMP (1 << 0)
/* packet_description.flags: TCP SYN was set. */
#define F_SYN_SET (1 << 1)
0120 
0121 struct packet_description {
0122     union {
0123         __be32 src;
0124         __be32 srcv6[4];
0125     };
0126     union {
0127         __be32 dst;
0128         __be32 dstv6[4];
0129     };
0130     union {
0131         __u32 ports;
0132         __u16 port16[2];
0133     };
0134     __u8 proto;
0135     __u8 flags;
0136 };
0137 
0138 struct ctl_value {
0139     union {
0140         __u64 value;
0141         __u32 ifindex;
0142         __u8 mac[6];
0143     };
0144 };
0145 
0146 struct vip_meta {
0147     __u32 flags;
0148     __u32 vip_num;
0149 };
0150 
0151 struct real_definition {
0152     union {
0153         __be32 dst;
0154         __be32 dstv6[4];
0155     };
0156     __u8 flags;
0157 };
0158 
0159 struct vip_stats {
0160     __u64 bytes;
0161     __u64 pkts;
0162 };
0163 
0164 struct eth_hdr {
0165     unsigned char eth_dest[ETH_ALEN];
0166     unsigned char eth_source[ETH_ALEN];
0167     unsigned short eth_proto;
0168 };
0169 
0170 struct {
0171     __uint(type, BPF_MAP_TYPE_HASH);
0172     __uint(max_entries, MAX_VIPS);
0173     __type(key, struct vip);
0174     __type(value, struct vip_meta);
0175 } vip_map SEC(".maps");
0176 
0177 struct {
0178     __uint(type, BPF_MAP_TYPE_ARRAY);
0179     __uint(max_entries, CH_RINGS_SIZE);
0180     __type(key, __u32);
0181     __type(value, __u32);
0182 } ch_rings SEC(".maps");
0183 
0184 struct {
0185     __uint(type, BPF_MAP_TYPE_ARRAY);
0186     __uint(max_entries, MAX_REALS);
0187     __type(key, __u32);
0188     __type(value, struct real_definition);
0189 } reals SEC(".maps");
0190 
0191 struct {
0192     __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
0193     __uint(max_entries, MAX_VIPS);
0194     __type(key, __u32);
0195     __type(value, struct vip_stats);
0196 } stats SEC(".maps");
0197 
0198 struct {
0199     __uint(type, BPF_MAP_TYPE_ARRAY);
0200     __uint(max_entries, CTL_MAP_SIZE);
0201     __type(key, __u32);
0202     __type(value, struct ctl_value);
0203 } ctl_array SEC(".maps");
0204 
0205 static __always_inline __u32 get_packet_hash(struct packet_description *pckt,
0206                          bool ipv6)
0207 {
0208     if (ipv6)
0209         return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
0210                     pckt->ports, CH_RINGS_SIZE);
0211     else
0212         return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
0213 }
0214 
0215 static __always_inline bool get_packet_dst(struct real_definition **real,
0216                        struct packet_description *pckt,
0217                        struct vip_meta *vip_info,
0218                        bool is_ipv6)
0219 {
0220     __u32 hash = get_packet_hash(pckt, is_ipv6) % RING_SIZE;
0221     __u32 key = RING_SIZE * vip_info->vip_num + hash;
0222     __u32 *real_pos;
0223 
0224     real_pos = bpf_map_lookup_elem(&ch_rings, &key);
0225     if (!real_pos)
0226         return false;
0227     key = *real_pos;
0228     *real = bpf_map_lookup_elem(&reals, &key);
0229     if (!(*real))
0230         return false;
0231     return true;
0232 }
0233 
0234 static __always_inline int parse_icmpv6(void *data, void *data_end, __u64 off,
0235                     struct packet_description *pckt)
0236 {
0237     struct icmp6hdr *icmp_hdr;
0238     struct ipv6hdr *ip6h;
0239 
0240     icmp_hdr = data + off;
0241     if (icmp_hdr + 1 > data_end)
0242         return TC_ACT_SHOT;
0243     if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
0244         return TC_ACT_OK;
0245     off += sizeof(struct icmp6hdr);
0246     ip6h = data + off;
0247     if (ip6h + 1 > data_end)
0248         return TC_ACT_SHOT;
0249     pckt->proto = ip6h->nexthdr;
0250     pckt->flags |= F_ICMP;
0251     memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
0252     memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
0253     return TC_ACT_UNSPEC;
0254 }
0255 
0256 static __always_inline int parse_icmp(void *data, void *data_end, __u64 off,
0257                       struct packet_description *pckt)
0258 {
0259     struct icmphdr *icmp_hdr;
0260     struct iphdr *iph;
0261 
0262     icmp_hdr = data + off;
0263     if (icmp_hdr + 1 > data_end)
0264         return TC_ACT_SHOT;
0265     if (icmp_hdr->type != ICMP_DEST_UNREACH ||
0266         icmp_hdr->code != ICMP_FRAG_NEEDED)
0267         return TC_ACT_OK;
0268     off += sizeof(struct icmphdr);
0269     iph = data + off;
0270     if (iph + 1 > data_end)
0271         return TC_ACT_SHOT;
0272     if (iph->ihl != 5)
0273         return TC_ACT_SHOT;
0274     pckt->proto = iph->protocol;
0275     pckt->flags |= F_ICMP;
0276     pckt->src = iph->daddr;
0277     pckt->dst = iph->saddr;
0278     return TC_ACT_UNSPEC;
0279 }
0280 
0281 static __always_inline bool parse_udp(void *data, __u64 off, void *data_end,
0282                       struct packet_description *pckt)
0283 {
0284     struct udphdr *udp;
0285     udp = data + off;
0286 
0287     if (udp + 1 > data_end)
0288         return false;
0289 
0290     if (!(pckt->flags & F_ICMP)) {
0291         pckt->port16[0] = udp->source;
0292         pckt->port16[1] = udp->dest;
0293     } else {
0294         pckt->port16[0] = udp->dest;
0295         pckt->port16[1] = udp->source;
0296     }
0297     return true;
0298 }
0299 
0300 static __always_inline bool parse_tcp(void *data, __u64 off, void *data_end,
0301                       struct packet_description *pckt)
0302 {
0303     struct tcphdr *tcp;
0304 
0305     tcp = data + off;
0306     if (tcp + 1 > data_end)
0307         return false;
0308 
0309     if (tcp->syn)
0310         pckt->flags |= F_SYN_SET;
0311 
0312     if (!(pckt->flags & F_ICMP)) {
0313         pckt->port16[0] = tcp->source;
0314         pckt->port16[1] = tcp->dest;
0315     } else {
0316         pckt->port16[0] = tcp->dest;
0317         pckt->port16[1] = tcp->source;
0318     }
0319     return true;
0320 }
0321 
0322 static __always_inline int process_packet(void *data, __u64 off, void *data_end,
0323                       bool is_ipv6, struct __sk_buff *skb)
0324 {
0325     void *pkt_start = (void *)(long)skb->data;
0326     struct packet_description pckt = {};
0327     struct eth_hdr *eth = pkt_start;
0328     struct bpf_tunnel_key tkey = {};
0329     struct vip_stats *data_stats;
0330     struct real_definition *dst;
0331     struct vip_meta *vip_info;
0332     struct ctl_value *cval;
0333     __u32 v4_intf_pos = 1;
0334     __u32 v6_intf_pos = 2;
0335     struct ipv6hdr *ip6h;
0336     struct vip vip = {};
0337     struct iphdr *iph;
0338     int tun_flag = 0;
0339     __u16 pkt_bytes;
0340     __u64 iph_len;
0341     __u32 ifindex;
0342     __u8 protocol;
0343     __u32 vip_num;
0344     int action;
0345 
0346     tkey.tunnel_ttl = 64;
0347     if (is_ipv6) {
0348         ip6h = data + off;
0349         if (ip6h + 1 > data_end)
0350             return TC_ACT_SHOT;
0351 
0352         iph_len = sizeof(struct ipv6hdr);
0353         protocol = ip6h->nexthdr;
0354         pckt.proto = protocol;
0355         pkt_bytes = bpf_ntohs(ip6h->payload_len);
0356         off += iph_len;
0357         if (protocol == IPPROTO_FRAGMENT) {
0358             return TC_ACT_SHOT;
0359         } else if (protocol == IPPROTO_ICMPV6) {
0360             action = parse_icmpv6(data, data_end, off, &pckt);
0361             if (action >= 0)
0362                 return action;
0363             off += IPV6_PLUS_ICMP_HDR;
0364         } else {
0365             memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
0366             memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
0367         }
0368     } else {
0369         iph = data + off;
0370         if (iph + 1 > data_end)
0371             return TC_ACT_SHOT;
0372         if (iph->ihl != 5)
0373             return TC_ACT_SHOT;
0374 
0375         protocol = iph->protocol;
0376         pckt.proto = protocol;
0377         pkt_bytes = bpf_ntohs(iph->tot_len);
0378         off += IPV4_HDR_LEN_NO_OPT;
0379 
0380         if (iph->frag_off & PCKT_FRAGMENTED)
0381             return TC_ACT_SHOT;
0382         if (protocol == IPPROTO_ICMP) {
0383             action = parse_icmp(data, data_end, off, &pckt);
0384             if (action >= 0)
0385                 return action;
0386             off += IPV4_PLUS_ICMP_HDR;
0387         } else {
0388             pckt.src = iph->saddr;
0389             pckt.dst = iph->daddr;
0390         }
0391     }
0392     protocol = pckt.proto;
0393 
0394     if (protocol == IPPROTO_TCP) {
0395         if (!parse_tcp(data, off, data_end, &pckt))
0396             return TC_ACT_SHOT;
0397     } else if (protocol == IPPROTO_UDP) {
0398         if (!parse_udp(data, off, data_end, &pckt))
0399             return TC_ACT_SHOT;
0400     } else {
0401         return TC_ACT_SHOT;
0402     }
0403 
0404     if (is_ipv6)
0405         memcpy(vip.daddr.v6, pckt.dstv6, 16);
0406     else
0407         vip.daddr.v4 = pckt.dst;
0408 
0409     vip.dport = pckt.port16[1];
0410     vip.protocol = pckt.proto;
0411     vip_info = bpf_map_lookup_elem(&vip_map, &vip);
0412     if (!vip_info) {
0413         vip.dport = 0;
0414         vip_info = bpf_map_lookup_elem(&vip_map, &vip);
0415         if (!vip_info)
0416             return TC_ACT_SHOT;
0417         pckt.port16[1] = 0;
0418     }
0419 
0420     if (vip_info->flags & F_HASH_NO_SRC_PORT)
0421         pckt.port16[0] = 0;
0422 
0423     if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
0424         return TC_ACT_SHOT;
0425 
0426     if (dst->flags & F_IPV6) {
0427         cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
0428         if (!cval)
0429             return TC_ACT_SHOT;
0430         ifindex = cval->ifindex;
0431         memcpy(tkey.remote_ipv6, dst->dstv6, 16);
0432         tun_flag = BPF_F_TUNINFO_IPV6;
0433     } else {
0434         cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
0435         if (!cval)
0436             return TC_ACT_SHOT;
0437         ifindex = cval->ifindex;
0438         tkey.remote_ipv4 = dst->dst;
0439     }
0440     vip_num = vip_info->vip_num;
0441     data_stats = bpf_map_lookup_elem(&stats, &vip_num);
0442     if (!data_stats)
0443         return TC_ACT_SHOT;
0444     data_stats->pkts++;
0445     data_stats->bytes += pkt_bytes;
0446     bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
0447     *(u32 *)eth->eth_dest = tkey.remote_ipv4;
0448     return bpf_redirect(ifindex, 0);
0449 }
0450 
0451 SEC("tc")
0452 int balancer_ingress(struct __sk_buff *ctx)
0453 {
0454     void *data_end = (void *)(long)ctx->data_end;
0455     void *data = (void *)(long)ctx->data;
0456     struct eth_hdr *eth = data;
0457     __u32 eth_proto;
0458     __u32 nh_off;
0459 
0460     nh_off = sizeof(struct eth_hdr);
0461     if (data + nh_off > data_end)
0462         return TC_ACT_SHOT;
0463     eth_proto = eth->eth_proto;
0464     if (eth_proto == bpf_htons(ETH_P_IP))
0465         return process_packet(data, nh_off, data_end, false, ctx);
0466     else if (eth_proto == bpf_htons(ETH_P_IPV6))
0467         return process_packet(data, nh_off, data_end, true, ctx);
0468     else
0469         return TC_ACT_SHOT;
0470 }
0471 char _license[] SEC("license") = "GPL";