Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 #include <test_progs.h>
0003 #include "cgroup_helpers.h"
0004 
/*
 * Socket option level used by this test. 0xdeadbeef is an arbitrary value;
 * presumably it is handled only by the cgroup getsockopt/setsockopt BPF
 * programs that run_test() attaches. It stays a macro because the value does
 * not fit in an int enum constant.
 */
#define SOL_CUSTOM		0xdeadbeef

/* Option names at level SOL_CUSTOM; start_server() sets each of them to 1. */
enum {
	CUSTOM_INHERIT1 = 0,	/* expected to be inherited by accept()ed sockets */
	CUSTOM_INHERIT2 = 1,	/* expected to be inherited by accept()ed sockets */
	CUSTOM_LISTENER = 2,	/* expected to stay on the listener only */
};
0009 
/*
 * Open a stream client socket and connect it to the address @server_fd is
 * bound to.
 *
 * The server address is looked up with getsockname() first. The client
 * socket is then created in that address's family (addr.ss_family) rather
 * than a hard-coded AF_INET, so the helper follows whatever family the
 * server socket uses.
 *
 * Returns the connected client fd on success, or -1 on failure. An error is
 * logged and no fd is leaked.
 */
static int connect_to_server(int server_fd)
{
	struct sockaddr_storage addr;
	socklen_t len = sizeof(addr);
	int fd;

	/* Resolve the server address first so we know which family to use. */
	if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
		log_err("Failed to get server addr");
		return -1;
	}

	fd = socket(addr.ss_family, SOCK_STREAM, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	/* len now holds the real size of the returned address. */
	if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
		log_err("Fail to connect to server");
		close(fd);
		return -1;
	}

	return fd;
}
0038 
0039 static int verify_sockopt(int fd, int optname, const char *msg, char expected)
0040 {
0041     socklen_t optlen = 1;
0042     char buf = 0;
0043     int err;
0044 
0045     err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
0046     if (err) {
0047         log_err("%s: failed to call getsockopt", msg);
0048         return 1;
0049     }
0050 
0051     printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
0052 
0053     if (buf != expected) {
0054         log_err("%s: unexpected getsockopt value %d != %d", msg,
0055             buf, expected);
0056         return 1;
0057     }
0058 
0059     return 0;
0060 }
0061 
0062 static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
0063 static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
0064 
0065 static void *server_thread(void *arg)
0066 {
0067     struct sockaddr_storage addr;
0068     socklen_t len = sizeof(addr);
0069     int fd = *(int *)arg;
0070     int client_fd;
0071     int err = 0;
0072 
0073     err = listen(fd, 1);
0074 
0075     pthread_mutex_lock(&server_started_mtx);
0076     pthread_cond_signal(&server_started);
0077     pthread_mutex_unlock(&server_started_mtx);
0078 
0079     if (CHECK_FAIL(err < 0)) {
0080         perror("Failed to listed on socket");
0081         return NULL;
0082     }
0083 
0084     err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
0085     err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
0086     err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
0087 
0088     client_fd = accept(fd, (struct sockaddr *)&addr, &len);
0089     if (CHECK_FAIL(client_fd < 0)) {
0090         perror("Failed to accept client");
0091         return NULL;
0092     }
0093 
0094     err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
0095     err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
0096     err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
0097 
0098     close(client_fd);
0099 
0100     return (void *)(long)err;
0101 }
0102 
0103 static int start_server(void)
0104 {
0105     struct sockaddr_in addr = {
0106         .sin_family = AF_INET,
0107         .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
0108     };
0109     char buf;
0110     int err;
0111     int fd;
0112     int i;
0113 
0114     fd = socket(AF_INET, SOCK_STREAM, 0);
0115     if (fd < 0) {
0116         log_err("Failed to create server socket");
0117         return -1;
0118     }
0119 
0120     for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
0121         buf = 0x01;
0122         err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
0123         if (err) {
0124             log_err("Failed to call setsockopt(%d)", i);
0125             close(fd);
0126             return -1;
0127         }
0128     }
0129 
0130     if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
0131         log_err("Failed to bind socket");
0132         close(fd);
0133         return -1;
0134     }
0135 
0136     return fd;
0137 }
0138 
0139 static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title,
0140                const char *prog_name)
0141 {
0142     enum bpf_attach_type attach_type;
0143     enum bpf_prog_type prog_type;
0144     struct bpf_program *prog;
0145     int err;
0146 
0147     err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
0148     if (err) {
0149         log_err("Failed to deduct types for %s BPF program", prog_name);
0150         return -1;
0151     }
0152 
0153     prog = bpf_object__find_program_by_name(obj, prog_name);
0154     if (!prog) {
0155         log_err("Failed to find %s BPF program", prog_name);
0156         return -1;
0157     }
0158 
0159     err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
0160                   attach_type, 0);
0161     if (err) {
0162         log_err("Failed to attach %s BPF program", prog_name);
0163         return -1;
0164     }
0165 
0166     return 0;
0167 }
0168 
0169 static void run_test(int cgroup_fd)
0170 {
0171     int server_fd = -1, client_fd;
0172     struct bpf_object *obj;
0173     void *server_err;
0174     pthread_t tid;
0175     int err;
0176 
0177     obj = bpf_object__open_file("sockopt_inherit.o", NULL);
0178     if (!ASSERT_OK_PTR(obj, "obj_open"))
0179         return;
0180 
0181     err = bpf_object__load(obj);
0182     if (!ASSERT_OK(err, "obj_load"))
0183         goto close_bpf_object;
0184 
0185     err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt", "_getsockopt");
0186     if (CHECK_FAIL(err))
0187         goto close_bpf_object;
0188 
0189     err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt", "_setsockopt");
0190     if (CHECK_FAIL(err))
0191         goto close_bpf_object;
0192 
0193     server_fd = start_server();
0194     if (CHECK_FAIL(server_fd < 0))
0195         goto close_bpf_object;
0196 
0197     pthread_mutex_lock(&server_started_mtx);
0198     if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread,
0199                       (void *)&server_fd))) {
0200         pthread_mutex_unlock(&server_started_mtx);
0201         goto close_server_fd;
0202     }
0203     pthread_cond_wait(&server_started, &server_started_mtx);
0204     pthread_mutex_unlock(&server_started_mtx);
0205 
0206     client_fd = connect_to_server(server_fd);
0207     if (CHECK_FAIL(client_fd < 0))
0208         goto close_server_fd;
0209 
0210     CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0));
0211     CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0));
0212     CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0));
0213 
0214     pthread_join(tid, &server_err);
0215 
0216     err = (int)(long)server_err;
0217     CHECK_FAIL(err);
0218 
0219     close(client_fd);
0220 
0221 close_server_fd:
0222     close(server_fd);
0223 close_bpf_object:
0224     bpf_object__close(obj);
0225 }
0226 
/*
 * Test entry point.
 *
 * Joins the "/sockopt_inherit" test cgroup and runs the scenario in it. The
 * returned cgroup fd is what run_test() attaches the BPF programs to, and it
 * is closed afterwards.
 */
void test_sockopt_inherit(void)
{
	int cg_fd = test__join_cgroup("/sockopt_inherit");

	if (!CHECK_FAIL(cg_fd < 0)) {
		run_test(cg_fd);
		close(cg_fd);
	}
}