Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /* Copyright (c) 2018 Facebook */
0003 
0004 #include <stdlib.h>
0005 #include <linux/in.h>
0006 #include <linux/ip.h>
0007 #include <linux/ipv6.h>
0008 #include <linux/tcp.h>
0009 #include <linux/udp.h>
0010 #include <linux/bpf.h>
0011 #include <linux/types.h>
0012 #include <linux/if_ether.h>
0013 
0014 #include <bpf/bpf_endian.h>
0015 #include <bpf/bpf_helpers.h>
0016 #include "test_select_reuseport_common.h"
0017 
/*
 * Fallback for headers that do not provide offsetof().
 * Uses the compiler builtin rather than the classic
 * "&((TYPE *)0)->MEMBER" trick, which dereferences a null pointer and is
 * undefined behavior in ISO C.
 */
#ifndef offsetof
#define offsetof(TYPE, MEMBER) __builtin_offsetof(TYPE, MEMBER)
#endif
0021 
/*
 * Single-entry array-of-maps. _select_by_skb_data() looks up entry 0 and
 * passes the resulting inner map to bpf_sk_select_reuseport() as the
 * socket array.
 * NOTE(review): presumably userspace installs a reuseport socket array
 * here (the __u32 value being its map fd) — confirm against the harness.
 */
0022 struct {
0023     __uint(type, BPF_MAP_TYPE_ARRAY_OF_MAPS);
0024     __uint(max_entries, 1);
0025     __type(key, __u32);
0026     __type(value, __u32);
0027 } outer_map SEC(".maps");
0028 
/*
 * Per-outcome counters indexed by enum result. The exit path of
 * _select_by_skb_data() increments the slot for the result it reached.
 * NR_RESULTS is presumably defined in test_select_reuseport_common.h.
 */
0029 struct {
0030     __uint(type, BPF_MAP_TYPE_ARRAY);
0031     __uint(max_entries, NR_RESULTS);
0032     __type(key, __u32);
0033     __type(value, __u32);
0034 } result_map SEC(".maps");
0035 
/*
 * One-shot socket index override. When entry 0 is not -1,
 * _select_by_skb_data() uses it instead of cmd->reuseport_index and then
 * writes -1 back, so the override applies to a single packet.
 * NOTE(review): presumably written by userspace — confirm.
 */
0036 struct {
0037     __uint(type, BPF_MAP_TYPE_ARRAY);
0038     __uint(max_entries, 1);
0039     __type(key, __u32);
0040     __type(value, int);
0041 } tmp_index_ovr_map SEC(".maps");
0042 
/*
 * Entry 0 receives the __LINE__ captured by GOTO_DONE for the most
 * recently recorded packet (overwritten on every record).
 */
0043 struct {
0044     __uint(type, BPF_MAP_TYPE_ARRAY);
0045     __uint(max_entries, 1);
0046     __type(key, __u32);
0047     __type(value, __u32);
0048 } linum_map SEC(".maps");
0049 
/*
 * Entry 0 receives the struct data_check snapshot (skb metadata, IP
 * addresses and L4 ports) built for the most recently recorded packet
 * (overwritten on every record).
 */
0050 struct {
0051     __uint(type, BPF_MAP_TYPE_ARRAY);
0052     __uint(max_entries, 1);
0053     __type(key, __u32);
0054     __type(value, struct data_check);
0055 } data_check_map SEC(".maps");
0056 
/*
 * Record an outcome and jump to the common exit path.
 *
 * Expects the calling function to declare `result` and `linum` and to
 * provide a `done:` label. `linum` captures the source line of the exit
 * point; the exit path stores it in linum_map.
 *
 * Wrapped in do { } while (0) so the macro acts as exactly one statement
 * (safe in unbraced if/else) without relying on GNU statement expressions.
 */
#define GOTO_DONE(_result) do {		\
	result = (_result);		\
	linum = __LINE__;		\
	goto done;			\
} while (0)
0062 
0063 SEC("sk_reuseport")
0064 int _select_by_skb_data(struct sk_reuseport_md *reuse_md)
0065 {
0066     __u32 linum, index = 0, flags = 0, index_zero = 0;
0067     __u32 *result_cnt, *linum_value;
0068     struct data_check data_check = {};
0069     struct cmd *cmd, cmd_copy;
0070     void *data, *data_end;
0071     void *reuseport_array;
0072     enum result result;
0073     int *index_ovr;
0074     int err;
0075 
0076     data = reuse_md->data;
0077     data_end = reuse_md->data_end;
0078     data_check.len = reuse_md->len;
0079     data_check.eth_protocol = reuse_md->eth_protocol;
0080     data_check.ip_protocol = reuse_md->ip_protocol;
0081     data_check.hash = reuse_md->hash;
0082     data_check.bind_inany = reuse_md->bind_inany;
0083     if (data_check.eth_protocol == bpf_htons(ETH_P_IP)) {
0084         if (bpf_skb_load_bytes_relative(reuse_md,
0085                         offsetof(struct iphdr, saddr),
0086                         data_check.skb_addrs, 8,
0087                         BPF_HDR_START_NET))
0088             GOTO_DONE(DROP_MISC);
0089     } else {
0090         if (bpf_skb_load_bytes_relative(reuse_md,
0091                         offsetof(struct ipv6hdr, saddr),
0092                         data_check.skb_addrs, 32,
0093                         BPF_HDR_START_NET))
0094             GOTO_DONE(DROP_MISC);
0095     }
0096 
0097     /*
0098      * The ip_protocol could be a compile time decision
0099      * if the bpf_prog.o is dedicated to either TCP or
0100      * UDP.
0101      *
0102      * Otherwise, reuse_md->ip_protocol or
0103      * the protocol field in the iphdr can be used.
0104      */
0105     if (data_check.ip_protocol == IPPROTO_TCP) {
0106         struct tcphdr *th = data;
0107 
0108         if (th + 1 > data_end)
0109             GOTO_DONE(DROP_MISC);
0110 
0111         data_check.skb_ports[0] = th->source;
0112         data_check.skb_ports[1] = th->dest;
0113 
0114         if (th->fin)
0115             /* The connection is being torn down at the end of a
0116              * test. It can't contain a cmd, so return early.
0117              */
0118             return SK_PASS;
0119 
0120         if ((th->doff << 2) + sizeof(*cmd) > data_check.len)
0121             GOTO_DONE(DROP_ERR_SKB_DATA);
0122         if (bpf_skb_load_bytes(reuse_md, th->doff << 2, &cmd_copy,
0123                        sizeof(cmd_copy)))
0124             GOTO_DONE(DROP_MISC);
0125         cmd = &cmd_copy;
0126     } else if (data_check.ip_protocol == IPPROTO_UDP) {
0127         struct udphdr *uh = data;
0128 
0129         if (uh + 1 > data_end)
0130             GOTO_DONE(DROP_MISC);
0131 
0132         data_check.skb_ports[0] = uh->source;
0133         data_check.skb_ports[1] = uh->dest;
0134 
0135         if (sizeof(struct udphdr) + sizeof(*cmd) > data_check.len)
0136             GOTO_DONE(DROP_ERR_SKB_DATA);
0137         if (data + sizeof(struct udphdr) + sizeof(*cmd) > data_end) {
0138             if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr),
0139                            &cmd_copy, sizeof(cmd_copy)))
0140                 GOTO_DONE(DROP_MISC);
0141             cmd = &cmd_copy;
0142         } else {
0143             cmd = data + sizeof(struct udphdr);
0144         }
0145     } else {
0146         GOTO_DONE(DROP_MISC);
0147     }
0148 
0149     reuseport_array = bpf_map_lookup_elem(&outer_map, &index_zero);
0150     if (!reuseport_array)
0151         GOTO_DONE(DROP_ERR_INNER_MAP);
0152 
0153     index = cmd->reuseport_index;
0154     index_ovr = bpf_map_lookup_elem(&tmp_index_ovr_map, &index_zero);
0155     if (!index_ovr)
0156         GOTO_DONE(DROP_MISC);
0157 
0158     if (*index_ovr != -1) {
0159         index = *index_ovr;
0160         *index_ovr = -1;
0161     }
0162     err = bpf_sk_select_reuseport(reuse_md, reuseport_array, &index,
0163                       flags);
0164     if (!err)
0165         GOTO_DONE(PASS);
0166 
0167     if (cmd->pass_on_failure)
0168         GOTO_DONE(PASS_ERR_SK_SELECT_REUSEPORT);
0169     else
0170         GOTO_DONE(DROP_ERR_SK_SELECT_REUSEPORT);
0171 
0172 done:
0173     result_cnt = bpf_map_lookup_elem(&result_map, &result);
0174     if (!result_cnt)
0175         return SK_DROP;
0176 
0177     bpf_map_update_elem(&linum_map, &index_zero, &linum, BPF_ANY);
0178     bpf_map_update_elem(&data_check_map, &index_zero, &data_check, BPF_ANY);
0179 
0180     (*result_cnt)++;
0181     return result < PASS ? SK_DROP : SK_PASS;
0182 }
0183 
0184 char _license[] SEC("license") = "GPL";