0001
0002
0003
0004
0005 #include <test_progs.h>
0006 #include "cgroup_helpers.h"
0007 #include "network_helpers.h"
0008 #include "mptcp_sock.skel.h"
0009
/*
 * Fallback for userspace headers that do not export TCP_CA_NAME_MAX.
 * 16 matches the kernel's limit for a congestion-control name,
 * including the terminating NUL.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX 16
#endif

/*
 * Value stored in the BPF map (socket_storage_map) for each client socket.
 * It is read back with bpf_map_lookup_elem() keyed by the client fd.
 *
 * NOTE(review): the layout presumably must stay byte-for-byte identical to
 * the struct used by the BPF program (progs/mptcp_sock.c). Do not reorder
 * or resize fields here without changing that side too.
 */
struct mptcp_storage {
	__u32 invoked;			/* checked == 1 by the verifiers (presumably an invocation count) */
	__u32 is_mptcp;			/* checked == 0 for plain TCP, == 1 for MPTCP */
	struct sock *sk;		/* compared against 'first' in verify_msk() */
	__u32 token;			/* compared against the token the BPF program exports via .bss */
	struct sock *first;		/* expected to equal 'sk' for the MPTCP case */
	char ca_name[TCP_CA_NAME_MAX];	/* compared with the sysctl default congestion control */
};
0022
0023 static int verify_tsk(int map_fd, int client_fd)
0024 {
0025 int err, cfd = client_fd;
0026 struct mptcp_storage val;
0027
0028 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
0029 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
0030 return err;
0031
0032 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
0033 err++;
0034
0035 if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
0036 err++;
0037
0038 return err;
0039 }
0040
0041 static void get_msk_ca_name(char ca_name[])
0042 {
0043 size_t len;
0044 int fd;
0045
0046 fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
0047 if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
0048 return;
0049
0050 len = read(fd, ca_name, TCP_CA_NAME_MAX);
0051 if (!ASSERT_GT(len, 0, "failed to read ca_name"))
0052 goto err;
0053
0054 if (len > 0 && ca_name[len - 1] == '\n')
0055 ca_name[len - 1] = '\0';
0056
0057 err:
0058 close(fd);
0059 }
0060
0061 static int verify_msk(int map_fd, int client_fd, __u32 token)
0062 {
0063 char ca_name[TCP_CA_NAME_MAX];
0064 int err, cfd = client_fd;
0065 struct mptcp_storage val;
0066
0067 if (!ASSERT_GT(token, 0, "invalid token"))
0068 return -1;
0069
0070 get_msk_ca_name(ca_name);
0071
0072 err = bpf_map_lookup_elem(map_fd, &cfd, &val);
0073 if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
0074 return err;
0075
0076 if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
0077 err++;
0078
0079 if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
0080 err++;
0081
0082 if (!ASSERT_EQ(val.token, token, "unexpected token"))
0083 err++;
0084
0085 if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
0086 err++;
0087
0088 if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
0089 err++;
0090
0091 return err;
0092 }
0093
0094 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
0095 {
0096 int client_fd, prog_fd, map_fd, err;
0097 struct mptcp_sock *sock_skel;
0098
0099 sock_skel = mptcp_sock__open_and_load();
0100 if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
0101 return -EIO;
0102
0103 err = mptcp_sock__attach(sock_skel);
0104 if (!ASSERT_OK(err, "skel_attach"))
0105 goto out;
0106
0107 prog_fd = bpf_program__fd(sock_skel->progs._sockops);
0108 if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
0109 err = -EIO;
0110 goto out;
0111 }
0112
0113 map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
0114 if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
0115 err = -EIO;
0116 goto out;
0117 }
0118
0119 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
0120 if (!ASSERT_OK(err, "bpf_prog_attach"))
0121 goto out;
0122
0123 client_fd = connect_to_fd(server_fd, 0);
0124 if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
0125 err = -EIO;
0126 goto out;
0127 }
0128
0129 err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
0130 verify_tsk(map_fd, client_fd);
0131
0132 close(client_fd);
0133
0134 out:
0135 mptcp_sock__destroy(sock_skel);
0136 return err;
0137 }
0138
0139 static void test_base(void)
0140 {
0141 int server_fd, cgroup_fd;
0142
0143 cgroup_fd = test__join_cgroup("/mptcp");
0144 if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
0145 return;
0146
0147
0148 server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
0149 if (!ASSERT_GE(server_fd, 0, "start_server"))
0150 goto with_mptcp;
0151
0152 ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
0153
0154 close(server_fd);
0155
0156 with_mptcp:
0157
0158 server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
0159 if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
0160 goto close_cgroup_fd;
0161
0162 ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
0163
0164 close(server_fd);
0165
0166 close_cgroup_fd:
0167 close(cgroup_fd);
0168 }
0169
/* test_progs entry point for the mptcp tests; runs the "base" subtest. */
void test_mptcp(void)
{
	if (!test__start_subtest("base"))
		return;

	test_base();
}