0001
0002 #include <test_progs.h>
0003 #include "cgroup_helpers.h"
0004
/* Option level used with getsockopt()/setsockopt() throughout this test. */
#define SOL_CUSTOM 0xdeadbeef

/* Option names passed at level SOL_CUSTOM. */
enum {
	CUSTOM_INHERIT1 = 0,
	CUSTOM_INHERIT2 = 1,
	CUSTOM_LISTENER = 2,
};
0009
/*
 * Open a stream socket and connect it to the address @server_fd is bound to.
 *
 * The server address is looked up first so that the client socket is
 * created in the same address family as the server, rather than assuming
 * AF_INET.
 *
 * Returns the connected fd, or -1 on failure (any socket created here is
 * closed before returning).
 */
static int connect_to_server(int server_fd)
{
	struct sockaddr_storage addr;
	socklen_t len = sizeof(addr);
	int fd;

	if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
		log_err("Failed to get server addr");
		return -1;
	}

	fd = socket(addr.ss_family, SOCK_STREAM, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(fd, (const struct sockaddr *)&addr, len) < 0) {
		/* Log before close() so the reported errno is connect()'s. */
		log_err("Fail to connect to server");
		close(fd);
		return -1;
	}

	return fd;
}
0038
0039 static int verify_sockopt(int fd, int optname, const char *msg, char expected)
0040 {
0041 socklen_t optlen = 1;
0042 char buf = 0;
0043 int err;
0044
0045 err = getsockopt(fd, SOL_CUSTOM, optname, &buf, &optlen);
0046 if (err) {
0047 log_err("%s: failed to call getsockopt", msg);
0048 return 1;
0049 }
0050
0051 printf("%s %d: got=0x%x ? expected=0x%x\n", msg, optname, buf, expected);
0052
0053 if (buf != expected) {
0054 log_err("%s: unexpected getsockopt value %d != %d", msg,
0055 buf, expected);
0056 return 1;
0057 }
0058
0059 return 0;
0060 }
0061
/*
 * Handshake between server_thread() and run_test(): the server thread
 * signals server_started (while holding server_started_mtx) once its
 * listen() call has returned.
 */
static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER;
static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER;
0064
0065 static void *server_thread(void *arg)
0066 {
0067 struct sockaddr_storage addr;
0068 socklen_t len = sizeof(addr);
0069 int fd = *(int *)arg;
0070 int client_fd;
0071 int err = 0;
0072
0073 err = listen(fd, 1);
0074
0075 pthread_mutex_lock(&server_started_mtx);
0076 pthread_cond_signal(&server_started);
0077 pthread_mutex_unlock(&server_started_mtx);
0078
0079 if (CHECK_FAIL(err < 0)) {
0080 perror("Failed to listed on socket");
0081 return NULL;
0082 }
0083
0084 err += verify_sockopt(fd, CUSTOM_INHERIT1, "listen", 1);
0085 err += verify_sockopt(fd, CUSTOM_INHERIT2, "listen", 1);
0086 err += verify_sockopt(fd, CUSTOM_LISTENER, "listen", 1);
0087
0088 client_fd = accept(fd, (struct sockaddr *)&addr, &len);
0089 if (CHECK_FAIL(client_fd < 0)) {
0090 perror("Failed to accept client");
0091 return NULL;
0092 }
0093
0094 err += verify_sockopt(client_fd, CUSTOM_INHERIT1, "accept", 1);
0095 err += verify_sockopt(client_fd, CUSTOM_INHERIT2, "accept", 1);
0096 err += verify_sockopt(client_fd, CUSTOM_LISTENER, "accept", 0);
0097
0098 close(client_fd);
0099
0100 return (void *)(long)err;
0101 }
0102
0103 static int start_server(void)
0104 {
0105 struct sockaddr_in addr = {
0106 .sin_family = AF_INET,
0107 .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
0108 };
0109 char buf;
0110 int err;
0111 int fd;
0112 int i;
0113
0114 fd = socket(AF_INET, SOCK_STREAM, 0);
0115 if (fd < 0) {
0116 log_err("Failed to create server socket");
0117 return -1;
0118 }
0119
0120 for (i = CUSTOM_INHERIT1; i <= CUSTOM_LISTENER; i++) {
0121 buf = 0x01;
0122 err = setsockopt(fd, SOL_CUSTOM, i, &buf, 1);
0123 if (err) {
0124 log_err("Failed to call setsockopt(%d)", i);
0125 close(fd);
0126 return -1;
0127 }
0128 }
0129
0130 if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) {
0131 log_err("Failed to bind socket");
0132 close(fd);
0133 return -1;
0134 }
0135
0136 return fd;
0137 }
0138
0139 static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title,
0140 const char *prog_name)
0141 {
0142 enum bpf_attach_type attach_type;
0143 enum bpf_prog_type prog_type;
0144 struct bpf_program *prog;
0145 int err;
0146
0147 err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
0148 if (err) {
0149 log_err("Failed to deduct types for %s BPF program", prog_name);
0150 return -1;
0151 }
0152
0153 prog = bpf_object__find_program_by_name(obj, prog_name);
0154 if (!prog) {
0155 log_err("Failed to find %s BPF program", prog_name);
0156 return -1;
0157 }
0158
0159 err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
0160 attach_type, 0);
0161 if (err) {
0162 log_err("Failed to attach %s BPF program", prog_name);
0163 return -1;
0164 }
0165
0166 return 0;
0167 }
0168
0169 static void run_test(int cgroup_fd)
0170 {
0171 int server_fd = -1, client_fd;
0172 struct bpf_object *obj;
0173 void *server_err;
0174 pthread_t tid;
0175 int err;
0176
0177 obj = bpf_object__open_file("sockopt_inherit.o", NULL);
0178 if (!ASSERT_OK_PTR(obj, "obj_open"))
0179 return;
0180
0181 err = bpf_object__load(obj);
0182 if (!ASSERT_OK(err, "obj_load"))
0183 goto close_bpf_object;
0184
0185 err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt", "_getsockopt");
0186 if (CHECK_FAIL(err))
0187 goto close_bpf_object;
0188
0189 err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt", "_setsockopt");
0190 if (CHECK_FAIL(err))
0191 goto close_bpf_object;
0192
0193 server_fd = start_server();
0194 if (CHECK_FAIL(server_fd < 0))
0195 goto close_bpf_object;
0196
0197 pthread_mutex_lock(&server_started_mtx);
0198 if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread,
0199 (void *)&server_fd))) {
0200 pthread_mutex_unlock(&server_started_mtx);
0201 goto close_server_fd;
0202 }
0203 pthread_cond_wait(&server_started, &server_started_mtx);
0204 pthread_mutex_unlock(&server_started_mtx);
0205
0206 client_fd = connect_to_server(server_fd);
0207 if (CHECK_FAIL(client_fd < 0))
0208 goto close_server_fd;
0209
0210 CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT1, "connect", 0));
0211 CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_INHERIT2, "connect", 0));
0212 CHECK_FAIL(verify_sockopt(client_fd, CUSTOM_LISTENER, "connect", 0));
0213
0214 pthread_join(tid, &server_err);
0215
0216 err = (int)(long)server_err;
0217 CHECK_FAIL(err);
0218
0219 close(client_fd);
0220
0221 close_server_fd:
0222 close(server_fd);
0223 close_bpf_object:
0224 bpf_object__close(obj);
0225 }
0226
/*
 * Test entry point: join the "/sockopt_inherit" cgroup and, if that
 * succeeds, run the test in it and close the cgroup fd afterwards.
 */
void test_sockopt_inherit(void)
{
	int cg_fd = test__join_cgroup("/sockopt_inherit");

	if (!CHECK_FAIL(cg_fd < 0)) {
		run_test(cg_fd);
		close(cg_fd);
	}
}