0001
0002
0003 #include <stdbool.h>
0004 #include <stdint.h>
0005 #include <linux/stddef.h>
0006 #include <linux/if_ether.h>
0007 #include <linux/in.h>
0008 #include <linux/in6.h>
0009 #include <linux/ip.h>
0010 #include <linux/ipv6.h>
0011 #include <linux/tcp.h>
0012 #include <linux/udp.h>
0013 #include <linux/bpf.h>
0014 #include <linux/types.h>
0015 #include <bpf/bpf_endian.h>
0016 #include <bpf/bpf_helpers.h>
0017
/*
 * Result codes of the packet parsing helpers. NO_ERR is 0, so a caller can
 * treat any non-zero return as "drop" (edgewall() does exactly that).
 */
enum pkt_parse_err {
	NO_ERR = 0,
	BAD_IP6_HDR = 1,
	BAD_IP4GUE_HDR = 2,
	BAD_IP6GUE_HDR = 3,
};
0024
/* Single-bit flags OR-ed into pkt_info.flags while a packet is parsed. */
enum pkt_flag {
	TUNNEL            = 1 << 0,
	TCP_SYN           = 1 << 1,
	QUIC_INITIAL_FLAG = 1 << 2,
	TCP_ACK           = 1 << 3,
	TCP_RST           = 1 << 4
};
0032
/*
 * Key of v4_lpm_val_map (an LPM trie). filter_ipv4_lpm() always fills it with
 * prefixlen 32 and an IPv4 address copied from iphdr saddr/daddr.
 * NOTE(review): LPM trie keys appear to require the 32-bit prefix length as
 * the first member, followed by the address bytes -- keep this field order.
 */
struct v4_lpm_key {
	__u32 prefixlen;	/* prefix length in bits */
	__u32 src;		/* IPv4 address, network byte order as read from the header */
};
0037
/*
 * Value of v4_lpm_val_map. filter_ipv4_lpm() returns only @val; the embedded
 * key copy is not read by the code shown here.
 */
struct v4_lpm_val {
	struct v4_lpm_key key;
	__u8 val;
};
0042
/*
 * Hash map looked up by filter_ipv6_addr() with an IPv6 source address.
 * The value is read back as a __u8; a missing key yields 0.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, struct in6_addr);
	__type(value, bool);
} v6_addr_map SEC(".maps");
0049
/*
 * Hash map looked up by filter_ipv4_addr() with an IPv4 source address,
 * copied unchanged from iphdr.saddr (network byte order). The value is read
 * back as a __u8; a missing key yields 0.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, __u32);
	__type(value, bool);
} v4_addr_map SEC(".maps");
0056
/*
 * LPM trie looked up by filter_ipv4_lpm() with a /32 key built from an IPv4
 * source or destination address. Key and value sizes are declared explicitly
 * instead of through __type().
 * NOTE(review): BPF_F_NO_PREALLOC appears to be mandatory for LPM tries --
 * confirm before removing it.
 */
struct {
	__uint(type, BPF_MAP_TYPE_LPM_TRIE);
	__uint(max_entries, 16);
	__uint(key_size, sizeof(struct v4_lpm_key));
	__uint(value_size, sizeof(struct v4_lpm_val));
	__uint(map_flags, BPF_F_NO_PREALLOC);
} v4_lpm_val_map SEC(".maps");
0064
/*
 * Array indexed by TCP destination port (host byte order) in
 * filter_tcp_port(); the __u8 value becomes fw_match_info.tcp_dp_match.
 * NOTE(review): with max_entries 16, array lookups for ports >= 16 should
 * find nothing, so only ports 0-15 can ever match -- confirm this is intended.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u8);
} tcp_port_map SEC(".maps");
0071
/*
 * Array indexed by UDP source and destination port (host byte order) in
 * filter_udp_port(); the __u16 values become udp_sp_match / udp_dp_match.
 * NOTE(review): with max_entries 16, only ports 0-15 can ever match -- confirm.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u16);
} udp_port_map SEC(".maps");
0078
/* Which member of pkt_info.ip is valid. Values start at 1, so 0 means "unset". */
enum ip_type {
	V4 = 1,
	V6 = 2,
};
0080
/*
 * Lookup results gathered for one packet by filter_src_dst_ip() and
 * filter_transport_hdr(). edgewall() zero-initialises it, and a field stays 0
 * when its lookup misses or is never performed.
 */
struct fw_match_info {
	/* address / prefix map values */
	__u8 v4_src_ip_match, v6_src_ip_match;
	__u8 v4_src_prefix_match, v4_dst_prefix_match;
	/* port map values */
	__u8 tcp_dp_match;
	__u16 udp_sp_match, udp_dp_match;
	/* set only for TCP packets */
	bool is_tcp, is_tcp_syn;
};
0092
0093 struct pkt_info {
0094 enum ip_type type;
0095 union {
0096 struct iphdr *ipv4;
0097 struct ipv6hdr *ipv6;
0098 } ip;
0099 int sport;
0100 int dport;
0101 __u16 trans_hdr_offset;
0102 __u8 proto;
0103 __u8 flags;
0104 };
0105
/*
 * Return the Ethernet header at the start of the packet, or NULL when
 * fewer than sizeof(struct ethhdr) bytes lie before @data_end.
 */
static __always_inline struct ethhdr *parse_ethhdr(void *data, void *data_end)
{
	struct ethhdr *hdr = data;

	return ((void *)(hdr + 1) > data_end) ? NULL : hdr;
}
0115
0116 static __always_inline __u8 filter_ipv6_addr(const struct in6_addr *ipv6addr)
0117 {
0118 __u8 *leaf;
0119
0120 leaf = bpf_map_lookup_elem(&v6_addr_map, ipv6addr);
0121
0122 return leaf ? *leaf : 0;
0123 }
0124
0125 static __always_inline __u8 filter_ipv4_addr(const __u32 ipaddr)
0126 {
0127 __u8 *leaf;
0128
0129 leaf = bpf_map_lookup_elem(&v4_addr_map, &ipaddr);
0130
0131 return leaf ? *leaf : 0;
0132 }
0133
0134 static __always_inline __u8 filter_ipv4_lpm(const __u32 ipaddr)
0135 {
0136 struct v4_lpm_key v4_key = {};
0137 struct v4_lpm_val *lpm_val;
0138
0139 v4_key.src = ipaddr;
0140 v4_key.prefixlen = 32;
0141
0142 lpm_val = bpf_map_lookup_elem(&v4_lpm_val_map, &v4_key);
0143
0144 return lpm_val ? lpm_val->val : 0;
0145 }
0146
0147
0148 static __always_inline void
0149 filter_src_dst_ip(struct pkt_info* info, struct fw_match_info* match_info)
0150 {
0151 if (info->type == V6) {
0152 match_info->v6_src_ip_match =
0153 filter_ipv6_addr(&info->ip.ipv6->saddr);
0154 } else if (info->type == V4) {
0155 match_info->v4_src_ip_match =
0156 filter_ipv4_addr(info->ip.ipv4->saddr);
0157 match_info->v4_src_prefix_match =
0158 filter_ipv4_lpm(info->ip.ipv4->saddr);
0159 match_info->v4_dst_prefix_match =
0160 filter_ipv4_lpm(info->ip.ipv4->daddr);
0161 }
0162 }
0163
/*
 * Return a pointer @offset bytes into the packet starting at @data, or NULL
 * when @offset exceeds 255 or lies beyond @data_end. Only the start of the
 * header is bounds-checked here; callers must still check the full header
 * length before reading it (parse_tcp()/parse_udp() do).
 */
static __always_inline void *
get_transport_hdr(__u16 offset, void *data, void *data_end)
{
	/*
	 * NOTE(review): the 255 cap presumably keeps the variable offset bounded
	 * for the BPF verifier -- confirm before changing the shape of this check.
	 */
	if (offset > 255 || data + offset > data_end)
		return NULL;

	return data + offset;
}
0172
/*
 * True when the ACK/RST/SYN/FIN bits of @tcp's flag word are exactly @flag.
 * Other bits (PSH, URG, ECE, CWR, data offset, window) are masked out.
 * @flag is expected to be built from the TCP_FLAG_* constants, which share
 * the flag word's byte order.
 */
static __always_inline bool tcphdr_only_contains_flag(struct tcphdr *tcp,
						      __u32 flag)
{
	const __u32 tracked = TCP_FLAG_ACK | TCP_FLAG_RST | TCP_FLAG_SYN |
			      TCP_FLAG_FIN;
	__u32 present = tcp_flag_word(tcp) & tracked;

	return present == flag;
}
0179
0180 static __always_inline void set_tcp_flags(struct pkt_info *info,
0181 struct tcphdr *tcp) {
0182 if (tcphdr_only_contains_flag(tcp, TCP_FLAG_SYN))
0183 info->flags |= TCP_SYN;
0184 else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_ACK))
0185 info->flags |= TCP_ACK;
0186 else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_RST))
0187 info->flags |= TCP_RST;
0188 }
0189
0190 static __always_inline bool
0191 parse_tcp(struct pkt_info *info, void *transport_hdr, void *data_end)
0192 {
0193 struct tcphdr *tcp = transport_hdr;
0194
0195 if (tcp + 1 > data_end)
0196 return false;
0197
0198 info->sport = bpf_ntohs(tcp->source);
0199 info->dport = bpf_ntohs(tcp->dest);
0200 set_tcp_flags(info, tcp);
0201
0202 return true;
0203 }
0204
0205 static __always_inline bool
0206 parse_udp(struct pkt_info *info, void *transport_hdr, void *data_end)
0207 {
0208 struct udphdr *udp = transport_hdr;
0209
0210 if (udp + 1 > data_end)
0211 return false;
0212
0213 info->sport = bpf_ntohs(udp->source);
0214 info->dport = bpf_ntohs(udp->dest);
0215
0216 return true;
0217 }
0218
0219 static __always_inline __u8 filter_tcp_port(int port)
0220 {
0221 __u8 *leaf = bpf_map_lookup_elem(&tcp_port_map, &port);
0222
0223 return leaf ? *leaf : 0;
0224 }
0225
0226 static __always_inline __u16 filter_udp_port(int port)
0227 {
0228 __u16 *leaf = bpf_map_lookup_elem(&udp_port_map, &port);
0229
0230 return leaf ? *leaf : 0;
0231 }
0232
0233 static __always_inline bool
0234 filter_transport_hdr(void *transport_hdr, void *data_end,
0235 struct pkt_info *info, struct fw_match_info *match_info)
0236 {
0237 if (info->proto == IPPROTO_TCP) {
0238 if (!parse_tcp(info, transport_hdr, data_end))
0239 return false;
0240
0241 match_info->is_tcp = true;
0242 match_info->is_tcp_syn = (info->flags & TCP_SYN) > 0;
0243
0244 match_info->tcp_dp_match = filter_tcp_port(info->dport);
0245 } else if (info->proto == IPPROTO_UDP) {
0246 if (!parse_udp(info, transport_hdr, data_end))
0247 return false;
0248
0249 match_info->udp_dp_match = filter_udp_port(info->dport);
0250 match_info->udp_sp_match = filter_udp_port(info->sport);
0251 }
0252
0253 return true;
0254 }
0255
0256 static __always_inline __u8
0257 parse_gue_v6(struct pkt_info *info, struct ipv6hdr *ip6h, void *data_end)
0258 {
0259 struct udphdr *udp = (struct udphdr *)(ip6h + 1);
0260 void *encap_data = udp + 1;
0261
0262 if (udp + 1 > data_end)
0263 return BAD_IP6_HDR;
0264
0265 if (udp->dest != bpf_htons(6666))
0266 return NO_ERR;
0267
0268 info->flags |= TUNNEL;
0269
0270 if (encap_data + 1 > data_end)
0271 return BAD_IP6GUE_HDR;
0272
0273 if (*(__u8 *)encap_data & 0x30) {
0274 struct ipv6hdr *inner_ip6h = encap_data;
0275
0276 if (inner_ip6h + 1 > data_end)
0277 return BAD_IP6GUE_HDR;
0278
0279 info->type = V6;
0280 info->proto = inner_ip6h->nexthdr;
0281 info->ip.ipv6 = inner_ip6h;
0282 info->trans_hdr_offset += sizeof(struct ipv6hdr) + sizeof(struct udphdr);
0283 } else {
0284 struct iphdr *inner_ip4h = encap_data;
0285
0286 if (inner_ip4h + 1 > data_end)
0287 return BAD_IP6GUE_HDR;
0288
0289 info->type = V4;
0290 info->proto = inner_ip4h->protocol;
0291 info->ip.ipv4 = inner_ip4h;
0292 info->trans_hdr_offset += sizeof(struct iphdr) + sizeof(struct udphdr);
0293 }
0294
0295 return NO_ERR;
0296 }
0297
0298 static __always_inline __u8 parse_ipv6_gue(struct pkt_info *info,
0299 void *data, void *data_end)
0300 {
0301 struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
0302
0303 if (ip6h + 1 > data_end)
0304 return BAD_IP6_HDR;
0305
0306 info->proto = ip6h->nexthdr;
0307 info->ip.ipv6 = ip6h;
0308 info->type = V6;
0309 info->trans_hdr_offset = sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
0310
0311 if (info->proto == IPPROTO_UDP)
0312 return parse_gue_v6(info, ip6h, data_end);
0313
0314 return NO_ERR;
0315 }
0316
0317 SEC("xdp")
0318 int edgewall(struct xdp_md *ctx)
0319 {
0320 void *data_end = (void *)(long)(ctx->data_end);
0321 void *data = (void *)(long)(ctx->data);
0322 struct fw_match_info match_info = {};
0323 struct pkt_info info = {};
0324 __u8 parse_err = NO_ERR;
0325 void *transport_hdr;
0326 struct ethhdr *eth;
0327 bool filter_res;
0328 __u32 proto;
0329
0330 eth = parse_ethhdr(data, data_end);
0331 if (!eth)
0332 return XDP_DROP;
0333
0334 proto = eth->h_proto;
0335 if (proto != bpf_htons(ETH_P_IPV6))
0336 return XDP_DROP;
0337
0338 if (parse_ipv6_gue(&info, data, data_end))
0339 return XDP_DROP;
0340
0341 if (info.proto == IPPROTO_ICMPV6)
0342 return XDP_PASS;
0343
0344 if (info.proto != IPPROTO_TCP && info.proto != IPPROTO_UDP)
0345 return XDP_DROP;
0346
0347 filter_src_dst_ip(&info, &match_info);
0348
0349 transport_hdr = get_transport_hdr(info.trans_hdr_offset, data,
0350 data_end);
0351 if (!transport_hdr)
0352 return XDP_DROP;
0353
0354 filter_res = filter_transport_hdr(transport_hdr, data_end,
0355 &info, &match_info);
0356 if (!filter_res)
0357 return XDP_DROP;
0358
0359 if (match_info.is_tcp && !match_info.is_tcp_syn)
0360 return XDP_PASS;
0361
0362 return XDP_DROP;
0363 }
0364
/* License string for this BPF object, placed in the "license" section via SEC(). */
char LICENSE[] SEC("license") = "GPL";