Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /* Copyright (c) 2020, Tessares SA. */
0003 /* Copyright (c) 2022, SUSE. */
0004 
0005 #include <test_progs.h>
0006 #include "cgroup_helpers.h"
0007 #include "network_helpers.h"
0008 #include "mptcp_sock.skel.h"
0009 
/* Size of the congestion-control name buffer; fall back to 16 when the
 * headers included above do not provide it.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX 16
#endif

/*
 * Userspace view of one socket_storage_map value. The map is looked up with
 * the client socket fd as key in verify_tsk() and verify_msk().
 * NOTE(review): this layout presumably has to match the value type declared
 * by the BPF program behind mptcp_sock.skel.h -- confirm before reordering,
 * resizing or adding fields.
 */
struct mptcp_storage {
    __u32 invoked;          /* expected to be 1 by both verifiers */
    __u32 is_mptcp;         /* expected 1 for the MPTCP run, 0 for plain TCP */
    struct sock *sk;        /* compared against @first in verify_msk() */
    __u32 token;            /* compared against the skeleton's bss->token */
    struct sock *first;     /* expected to equal @sk in verify_msk() */
    char ca_name[TCP_CA_NAME_MAX]; /* compared with the tcp_congestion_control sysctl */
};
0022 
0023 static int verify_tsk(int map_fd, int client_fd)
0024 {
0025     int err, cfd = client_fd;
0026     struct mptcp_storage val;
0027 
0028     err = bpf_map_lookup_elem(map_fd, &cfd, &val);
0029     if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
0030         return err;
0031 
0032     if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
0033         err++;
0034 
0035     if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
0036         err++;
0037 
0038     return err;
0039 }
0040 
0041 static void get_msk_ca_name(char ca_name[])
0042 {
0043     size_t len;
0044     int fd;
0045 
0046     fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
0047     if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
0048         return;
0049 
0050     len = read(fd, ca_name, TCP_CA_NAME_MAX);
0051     if (!ASSERT_GT(len, 0, "failed to read ca_name"))
0052         goto err;
0053 
0054     if (len > 0 && ca_name[len - 1] == '\n')
0055         ca_name[len - 1] = '\0';
0056 
0057 err:
0058     close(fd);
0059 }
0060 
0061 static int verify_msk(int map_fd, int client_fd, __u32 token)
0062 {
0063     char ca_name[TCP_CA_NAME_MAX];
0064     int err, cfd = client_fd;
0065     struct mptcp_storage val;
0066 
0067     if (!ASSERT_GT(token, 0, "invalid token"))
0068         return -1;
0069 
0070     get_msk_ca_name(ca_name);
0071 
0072     err = bpf_map_lookup_elem(map_fd, &cfd, &val);
0073     if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
0074         return err;
0075 
0076     if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
0077         err++;
0078 
0079     if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
0080         err++;
0081 
0082     if (!ASSERT_EQ(val.token, token, "unexpected token"))
0083         err++;
0084 
0085     if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
0086         err++;
0087 
0088     if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
0089         err++;
0090 
0091     return err;
0092 }
0093 
0094 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
0095 {
0096     int client_fd, prog_fd, map_fd, err;
0097     struct mptcp_sock *sock_skel;
0098 
0099     sock_skel = mptcp_sock__open_and_load();
0100     if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
0101         return -EIO;
0102 
0103     err = mptcp_sock__attach(sock_skel);
0104     if (!ASSERT_OK(err, "skel_attach"))
0105         goto out;
0106 
0107     prog_fd = bpf_program__fd(sock_skel->progs._sockops);
0108     if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
0109         err = -EIO;
0110         goto out;
0111     }
0112 
0113     map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
0114     if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
0115         err = -EIO;
0116         goto out;
0117     }
0118 
0119     err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
0120     if (!ASSERT_OK(err, "bpf_prog_attach"))
0121         goto out;
0122 
0123     client_fd = connect_to_fd(server_fd, 0);
0124     if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
0125         err = -EIO;
0126         goto out;
0127     }
0128 
0129     err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
0130               verify_tsk(map_fd, client_fd);
0131 
0132     close(client_fd);
0133 
0134 out:
0135     mptcp_sock__destroy(sock_skel);
0136     return err;
0137 }
0138 
/*
 * "base" subtest: join a dedicated cgroup, then run the sockops checks
 * first against a plain TCP server and then against an MPTCP server.
 * If one server cannot be started, only that half is skipped.
 */
static void test_base(void)
{
    int cgroup_fd, srv;

    cgroup_fd = test__join_cgroup("/mptcp");
    if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
        return;

    /* without MPTCP */
    srv = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
    if (ASSERT_GE(srv, 0, "start_server")) {
        ASSERT_OK(run_test(cgroup_fd, srv, false), "run_test tcp");
        close(srv);
    }

    /* with MPTCP */
    srv = start_mptcp_server(AF_INET, NULL, 0, 0);
    if (ASSERT_GE(srv, 0, "start_mptcp_server")) {
        ASSERT_OK(run_test(cgroup_fd, srv, true), "run_test mptcp");
        close(srv);
    }

    close(cgroup_fd);
}
0169 
/* test_progs entry point: runs the "base" subtest when it is selected. */
void test_mptcp(void)
{
    if (!test__start_subtest("base"))
        return;

    test_base();
}