0001
0002
0003
0004
0005 #include <stddef.h>
0006 #include <stdbool.h>
0007 #include <string.h>
0008 #include <linux/bpf.h>
0009 #include <linux/if_ether.h>
0010 #include <linux/in.h>
0011 #include <linux/ip.h>
0012 #include <linux/ipv6.h>
0013 #include <linux/pkt_cls.h>
0014 #include <linux/tcp.h>
0015 #include <sys/socket.h>
0016 #include <bpf/bpf_helpers.h>
0017 #include <bpf/bpf_endian.h>
0018
0019
/*
 * Pinning mode value for the legacy map definition below.
 * NOTE(review): presumably mirrors PIN_GLOBAL_NS from iproute2's bpf_elf.h
 * (pin the map in the loader's global namespace) — confirm against the
 * loader actually used to attach this object.
 */
#define PIN_GLOBAL_NS 2

/*
 * Single-slot SOCKMAP holding the server socket that unmatched traffic to
 * destination port 4321 is steered to; handle_tcp() and handle_udp() read
 * it at key 0.
 *
 * This is a legacy SEC("maps") definition, not a BTF-defined map: the
 * loader interprets the struct by position. Do not reorder, add or remove
 * fields. NOTE(review): the layout appears to match iproute2's
 * struct bpf_elf_map — confirm before changing.
 */
struct {
	__u32 type;
	__u32 size_key;
	__u32 size_value;
	__u32 max_elem;
	__u32 flags;
	__u32 id;
	__u32 pinning;
} server_map SEC("maps") = {
	.type = BPF_MAP_TYPE_SOCKMAP,
	.size_key = sizeof(int),
	.size_value = sizeof(__u64),
	.max_elem = 1,
	.pinning = PIN_GLOBAL_NS,
};

/* License string emitted into the "license" ELF section for the loader. */
char _license[] SEC("license") = "GPL";
0040
0041
0042 static inline struct bpf_sock_tuple *
0043 get_tuple(struct __sk_buff *skb, bool *ipv4, bool *tcp)
0044 {
0045 void *data_end = (void *)(long)skb->data_end;
0046 void *data = (void *)(long)skb->data;
0047 struct bpf_sock_tuple *result;
0048 struct ethhdr *eth;
0049 __u64 tuple_len;
0050 __u8 proto = 0;
0051 __u64 ihl_len;
0052
0053 eth = (struct ethhdr *)(data);
0054 if (eth + 1 > data_end)
0055 return NULL;
0056
0057 if (eth->h_proto == bpf_htons(ETH_P_IP)) {
0058 struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
0059
0060 if (iph + 1 > data_end)
0061 return NULL;
0062 if (iph->ihl != 5)
0063
0064 return NULL;
0065 ihl_len = iph->ihl * 4;
0066 proto = iph->protocol;
0067 *ipv4 = true;
0068 result = (struct bpf_sock_tuple *)&iph->saddr;
0069 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
0070 struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
0071
0072 if (ip6h + 1 > data_end)
0073 return NULL;
0074 ihl_len = sizeof(*ip6h);
0075 proto = ip6h->nexthdr;
0076 *ipv4 = false;
0077 result = (struct bpf_sock_tuple *)&ip6h->saddr;
0078 } else {
0079 return (struct bpf_sock_tuple *)data;
0080 }
0081
0082 if (proto != IPPROTO_TCP && proto != IPPROTO_UDP)
0083 return NULL;
0084
0085 *tcp = (proto == IPPROTO_TCP);
0086 return result;
0087 }
0088
0089 static inline int
0090 handle_udp(struct __sk_buff *skb, struct bpf_sock_tuple *tuple, bool ipv4)
0091 {
0092 struct bpf_sock *sk;
0093 const int zero = 0;
0094 size_t tuple_len;
0095 __be16 dport;
0096 int ret;
0097
0098 tuple_len = ipv4 ? sizeof(tuple->ipv4) : sizeof(tuple->ipv6);
0099 if ((void *)tuple + tuple_len > (void *)(long)skb->data_end)
0100 return TC_ACT_SHOT;
0101
0102 sk = bpf_sk_lookup_udp(skb, tuple, tuple_len, BPF_F_CURRENT_NETNS, 0);
0103 if (sk)
0104 goto assign;
0105
0106 dport = ipv4 ? tuple->ipv4.dport : tuple->ipv6.dport;
0107 if (dport != bpf_htons(4321))
0108 return TC_ACT_OK;
0109
0110 sk = bpf_map_lookup_elem(&server_map, &zero);
0111 if (!sk)
0112 return TC_ACT_SHOT;
0113
0114 assign:
0115 ret = bpf_sk_assign(skb, sk, 0);
0116 bpf_sk_release(sk);
0117 return ret;
0118 }
0119
/*
 * Steer a TCP packet to a socket with bpf_sk_assign().
 *
 * A non-listening socket matching the packet's tuple (found with
 * bpf_skc_lookup_tcp() in the current netns) is assigned directly.
 * A matching listener is not used; its reference is dropped instead.
 * Then, for destination port 4321 only, the socket at key 0 of server_map
 * is assigned, provided it is in the listening state. Traffic to any other
 * port without a non-listening match is passed through unchanged.
 *
 * Returns:
 *   TC_ACT_SHOT  - tuple extends past the packet end, server_map is empty,
 *                  or the server_map socket is not listening
 *   TC_ACT_OK    - unrelated traffic (pass through)
 *   otherwise    - the bpf_sk_assign() result (0 on success)
 *
 * The exact control-flow shape here matters to the verifier: the "assign"
 * label is reached both with a sock_common pointer from
 * bpf_skc_lookup_tcp() and with a socket pointer from server_map.
 */
static inline int
handle_tcp(struct __sk_buff *skb, struct bpf_sock_tuple *tuple, bool ipv4)
{
	struct bpf_sock *sk;
	const int zero = 0;
	size_t tuple_len;
	__be16 dport;
	int ret;

	/* The lookup helper reads tuple_len bytes straight from the packet. */
	tuple_len = ipv4 ? sizeof(tuple->ipv4) : sizeof(tuple->ipv6);
	if ((void *)tuple + tuple_len > (void *)(long)skb->data_end)
		return TC_ACT_SHOT;

	sk = bpf_skc_lookup_tcp(skb, tuple, tuple_len, BPF_F_CURRENT_NETNS, 0);
	if (sk) {
		/* Any non-listening match (an existing flow) is used as-is. */
		if (sk->state != BPF_TCP_LISTEN)
			goto assign;
		/* A matching listener is dropped in favour of the map below. */
		bpf_sk_release(sk);
	}

	dport = ipv4 ? tuple->ipv4.dport : tuple->ipv6.dport;
	if (dport != bpf_htons(4321))
		return TC_ACT_OK;

	sk = bpf_map_lookup_elem(&server_map, &zero);
	if (!sk)
		return TC_ACT_SHOT;

	/* New connections may only be steered to a listening socket. */
	if (sk->state != BPF_TCP_LISTEN) {
		bpf_sk_release(sk);
		return TC_ACT_SHOT;
	}

assign:
	ret = bpf_sk_assign(skb, sk, 0);
	/* Release the lookup's reference whether or not the assign succeeded. */
	bpf_sk_release(sk);
	return ret;
}
0158
0159 SEC("tc")
0160 int bpf_sk_assign_test(struct __sk_buff *skb)
0161 {
0162 struct bpf_sock_tuple *tuple;
0163 bool ipv4 = false;
0164 bool tcp = false;
0165 int tuple_len;
0166 int ret = 0;
0167
0168 tuple = get_tuple(skb, &ipv4, &tcp);
0169 if (!tuple)
0170 return TC_ACT_SHOT;
0171
0172
0173
0174
0175
0176
0177 if (tcp)
0178 ret = handle_tcp(skb, tuple, ipv4);
0179 else
0180 ret = handle_udp(skb, tuple, ipv4);
0181
0182 return ret == 0 ? TC_ACT_OK : TC_ACT_SHOT;
0183 }