0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include <getopt.h>
0011 #include <stdio.h>
0012 #include <stdlib.h>
0013 #include <string.h>
0014 #include <errno.h>
0015 #include <unistd.h>
0016 #include <sys/stat.h>
0017 #include <sys/types.h>
0018 #include <linux/list.h>
0019 #include <linux/net.h>
0020 #include <linux/netlink.h>
0021 #include <linux/sock_diag.h>
0022 #include <linux/vm_sockets_diag.h>
0023 #include <netinet/tcp.h>
0024
0025 #include "timeout.h"
0026 #include "control.h"
0027 #include "util.h"
0028
0029
0030 struct vsock_stat {
0031 struct list_head list;
0032 struct vsock_diag_msg msg;
0033 };
0034
/*
 * Return a printable name for a socket type.
 *
 * Only SOCK_DGRAM and SOCK_STREAM are recognised; every other value
 * yields "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_DGRAM)
		return "DGRAM";

	if (type == SOCK_STREAM)
		return "STREAM";

	return "INVALID TYPE";
}
0046
/*
 * Return a printable name for a socket state expressed as a TCP_* value.
 *
 * States without an entry in the table (and out-of-range values) yield
 * "INVALID STATE".
 */
static const char *sock_state_str(int state)
{
	static const char *const names[] = {
		[TCP_ESTABLISHED]	= "CONNECTED",
		[TCP_SYN_SENT]		= "CONNECTING",
		[TCP_CLOSE]		= "UNCONNECTED",
		[TCP_LISTEN]		= "LISTEN",
		[TCP_CLOSING]		= "DISCONNECTING",
	};
	const size_t count = sizeof(names) / sizeof(names[0]);

	/* Unlisted slots in the table are NULL and count as invalid. */
	if (state < 0 || (size_t)state >= count || !names[state])
		return "INVALID STATE";

	return names[state];
}
0064
/*
 * Return a printable description of a shutdown mask.
 *
 * Values 1, 2 and 3 map to the receive, send and combined shutdown
 * names; anything else is printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 0 || shutdown > 3)
		return names[0];

	return names[shutdown];
}
0078
0079 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
0080 {
0081 if (cid == VMADDR_CID_ANY)
0082 fprintf(fp, "*:");
0083 else
0084 fprintf(fp, "%u:", cid);
0085
0086 if (port == VMADDR_PORT_ANY)
0087 fprintf(fp, "*");
0088 else
0089 fprintf(fp, "%u", port);
0090 }
0091
0092 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
0093 {
0094 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
0095 fprintf(fp, " ");
0096 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
0097 fprintf(fp, " %s %s %s %u\n",
0098 sock_type_str(st->msg.vdiag_type),
0099 sock_state_str(st->msg.vdiag_state),
0100 sock_shutdown_str(st->msg.vdiag_shutdown),
0101 st->msg.vdiag_ino);
0102 }
0103
0104 static void print_vsock_stats(FILE *fp, struct list_head *head)
0105 {
0106 struct vsock_stat *st;
0107
0108 list_for_each_entry(st, head, list)
0109 print_vsock_stat(fp, st);
0110 }
0111
0112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
0113 {
0114 struct vsock_stat *st;
0115 struct stat stat;
0116
0117 if (fstat(fd, &stat) < 0) {
0118 perror("fstat");
0119 exit(EXIT_FAILURE);
0120 }
0121
0122 list_for_each_entry(st, head, list)
0123 if (st->msg.vdiag_ino == stat.st_ino)
0124 return st;
0125
0126 fprintf(stderr, "cannot find fd %d\n", fd);
0127 exit(EXIT_FAILURE);
0128 }
0129
/*
 * Assert that the socket list is empty.
 *
 * On failure the unexpected entries are dumped to stderr and the process
 * exits with EXIT_FAILURE, matching the other check helpers in this file.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	exit(EXIT_FAILURE);
}
0138
0139 static void check_num_sockets(struct list_head *head, int expected)
0140 {
0141 struct list_head *node;
0142 int n = 0;
0143
0144 list_for_each(node, head)
0145 n++;
0146
0147 if (n != expected) {
0148 fprintf(stderr, "expected %d sockets, found %d\n",
0149 expected, n);
0150 print_vsock_stats(stderr, head);
0151 exit(EXIT_FAILURE);
0152 }
0153 }
0154
0155 static void check_socket_state(struct vsock_stat *st, __u8 state)
0156 {
0157 if (st->msg.vdiag_state != state) {
0158 fprintf(stderr, "expected socket state %#x, got %#x\n",
0159 state, st->msg.vdiag_state);
0160 exit(EXIT_FAILURE);
0161 }
0162 }
0163
0164 static void send_req(int fd)
0165 {
0166 struct sockaddr_nl nladdr = {
0167 .nl_family = AF_NETLINK,
0168 };
0169 struct {
0170 struct nlmsghdr nlh;
0171 struct vsock_diag_req vreq;
0172 } req = {
0173 .nlh = {
0174 .nlmsg_len = sizeof(req),
0175 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
0176 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
0177 },
0178 .vreq = {
0179 .sdiag_family = AF_VSOCK,
0180 .vdiag_states = ~(__u32)0,
0181 },
0182 };
0183 struct iovec iov = {
0184 .iov_base = &req,
0185 .iov_len = sizeof(req),
0186 };
0187 struct msghdr msg = {
0188 .msg_name = &nladdr,
0189 .msg_namelen = sizeof(nladdr),
0190 .msg_iov = &iov,
0191 .msg_iovlen = 1,
0192 };
0193
0194 for (;;) {
0195 if (sendmsg(fd, &msg, 0) < 0) {
0196 if (errno == EINTR)
0197 continue;
0198
0199 perror("sendmsg");
0200 exit(EXIT_FAILURE);
0201 }
0202
0203 return;
0204 }
0205 }
0206
0207 static ssize_t recv_resp(int fd, void *buf, size_t len)
0208 {
0209 struct sockaddr_nl nladdr = {
0210 .nl_family = AF_NETLINK,
0211 };
0212 struct iovec iov = {
0213 .iov_base = buf,
0214 .iov_len = len,
0215 };
0216 struct msghdr msg = {
0217 .msg_name = &nladdr,
0218 .msg_namelen = sizeof(nladdr),
0219 .msg_iov = &iov,
0220 .msg_iovlen = 1,
0221 };
0222 ssize_t ret;
0223
0224 do {
0225 ret = recvmsg(fd, &msg, 0);
0226 } while (ret < 0 && errno == EINTR);
0227
0228 if (ret < 0) {
0229 perror("recvmsg");
0230 exit(EXIT_FAILURE);
0231 }
0232
0233 return ret;
0234 }
0235
0236 static void add_vsock_stat(struct list_head *sockets,
0237 const struct vsock_diag_msg *resp)
0238 {
0239 struct vsock_stat *st;
0240
0241 st = malloc(sizeof(*st));
0242 if (!st) {
0243 perror("malloc");
0244 exit(EXIT_FAILURE);
0245 }
0246
0247 st->msg = *resp;
0248 list_add_tail(&st->list, sockets);
0249 }
0250
0251
0252
0253
0254 static void read_vsock_stat(struct list_head *sockets)
0255 {
0256 long buf[8192 / sizeof(long)];
0257 int fd;
0258
0259 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
0260 if (fd < 0) {
0261 perror("socket");
0262 exit(EXIT_FAILURE);
0263 }
0264
0265 send_req(fd);
0266
0267 for (;;) {
0268 const struct nlmsghdr *h;
0269 ssize_t ret;
0270
0271 ret = recv_resp(fd, buf, sizeof(buf));
0272 if (ret == 0)
0273 goto done;
0274 if (ret < sizeof(*h)) {
0275 fprintf(stderr, "short read of %zd bytes\n", ret);
0276 exit(EXIT_FAILURE);
0277 }
0278
0279 h = (struct nlmsghdr *)buf;
0280
0281 while (NLMSG_OK(h, ret)) {
0282 if (h->nlmsg_type == NLMSG_DONE)
0283 goto done;
0284
0285 if (h->nlmsg_type == NLMSG_ERROR) {
0286 const struct nlmsgerr *err = NLMSG_DATA(h);
0287
0288 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
0289 fprintf(stderr, "NLMSG_ERROR\n");
0290 else {
0291 errno = -err->error;
0292 perror("NLMSG_ERROR");
0293 }
0294
0295 exit(EXIT_FAILURE);
0296 }
0297
0298 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
0299 fprintf(stderr, "unexpected nlmsg_type %#x\n",
0300 h->nlmsg_type);
0301 exit(EXIT_FAILURE);
0302 }
0303 if (h->nlmsg_len <
0304 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
0305 fprintf(stderr, "short vsock_diag_msg\n");
0306 exit(EXIT_FAILURE);
0307 }
0308
0309 add_vsock_stat(sockets, NLMSG_DATA(h));
0310
0311 h = NLMSG_NEXT(h, ret);
0312 }
0313 }
0314
0315 done:
0316 close(fd);
0317 }
0318
0319 static void free_sock_stat(struct list_head *sockets)
0320 {
0321 struct vsock_stat *st;
0322 struct vsock_stat *next;
0323
0324 list_for_each_entry_safe(st, next, sockets, list)
0325 free(st);
0326 }
0327
0328 static void test_no_sockets(const struct test_opts *opts)
0329 {
0330 LIST_HEAD(sockets);
0331
0332 read_vsock_stat(&sockets);
0333
0334 check_no_sockets(&sockets);
0335 }
0336
0337 static void test_listen_socket_server(const struct test_opts *opts)
0338 {
0339 union {
0340 struct sockaddr sa;
0341 struct sockaddr_vm svm;
0342 } addr = {
0343 .svm = {
0344 .svm_family = AF_VSOCK,
0345 .svm_port = 1234,
0346 .svm_cid = VMADDR_CID_ANY,
0347 },
0348 };
0349 LIST_HEAD(sockets);
0350 struct vsock_stat *st;
0351 int fd;
0352
0353 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
0354
0355 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
0356 perror("bind");
0357 exit(EXIT_FAILURE);
0358 }
0359
0360 if (listen(fd, 1) < 0) {
0361 perror("listen");
0362 exit(EXIT_FAILURE);
0363 }
0364
0365 read_vsock_stat(&sockets);
0366
0367 check_num_sockets(&sockets, 1);
0368 st = find_vsock_stat(&sockets, fd);
0369 check_socket_state(st, TCP_LISTEN);
0370
0371 close(fd);
0372 free_sock_stat(&sockets);
0373 }
0374
0375 static void test_connect_client(const struct test_opts *opts)
0376 {
0377 int fd;
0378 LIST_HEAD(sockets);
0379 struct vsock_stat *st;
0380
0381 fd = vsock_stream_connect(opts->peer_cid, 1234);
0382 if (fd < 0) {
0383 perror("connect");
0384 exit(EXIT_FAILURE);
0385 }
0386
0387 read_vsock_stat(&sockets);
0388
0389 check_num_sockets(&sockets, 1);
0390 st = find_vsock_stat(&sockets, fd);
0391 check_socket_state(st, TCP_ESTABLISHED);
0392
0393 control_expectln("DONE");
0394 control_writeln("DONE");
0395
0396 close(fd);
0397 free_sock_stat(&sockets);
0398 }
0399
0400 static void test_connect_server(const struct test_opts *opts)
0401 {
0402 struct vsock_stat *st;
0403 LIST_HEAD(sockets);
0404 int client_fd;
0405
0406 client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
0407 if (client_fd < 0) {
0408 perror("accept");
0409 exit(EXIT_FAILURE);
0410 }
0411
0412 read_vsock_stat(&sockets);
0413
0414 check_num_sockets(&sockets, 1);
0415 st = find_vsock_stat(&sockets, client_fd);
0416 check_socket_state(st, TCP_ESTABLISHED);
0417
0418 control_writeln("DONE");
0419 control_expectln("DONE");
0420
0421 close(client_fd);
0422 free_sock_stat(&sockets);
0423 }
0424
0425 static struct test_case test_cases[] = {
0426 {
0427 .name = "No sockets",
0428 .run_server = test_no_sockets,
0429 },
0430 {
0431 .name = "Listen socket",
0432 .run_server = test_listen_socket_server,
0433 },
0434 {
0435 .name = "Connect",
0436 .run_client = test_connect_client,
0437 .run_server = test_connect_server,
0438 },
0439 {},
0440 };
0441
/* No short options are accepted; everything is a long option. */
static const char optstring[] = "";

/*
 * Long options understood by main(). Each `val` is the character that
 * getopt_long() returns for the option (no `flag` pointer is set).
 */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{},
};
0481
/* Print the help text to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	static const char text[] =
		"Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		" Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		" Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests. Must be launched in both\n"
		"guest and host. One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server. The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		" --help This help message\n"
		" --control-host <host> Server IP address to connect to\n"
		" --control-port <port> Server port to listen on/connect to\n"
		" --mode client|server Server or client mode\n"
		" --peer-cid <cid> CID of the other side\n"
		" --list List of tests that will be executed\n"
		" --skip <test_id> Test ID to skip;\n"
		" use multiple --skip options to skip more tests\n";

	/* The text contains no conversion specifiers, so it is written
	 * verbatim.
	 */
	fputs(text, stderr);
	exit(EXIT_FAILURE);
}
0512
0513 int main(int argc, char **argv)
0514 {
0515 const char *control_host = NULL;
0516 const char *control_port = NULL;
0517 struct test_opts opts = {
0518 .mode = TEST_MODE_UNSET,
0519 .peer_cid = VMADDR_CID_ANY,
0520 };
0521
0522 init_signals();
0523
0524 for (;;) {
0525 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
0526
0527 if (opt == -1)
0528 break;
0529
0530 switch (opt) {
0531 case 'H':
0532 control_host = optarg;
0533 break;
0534 case 'm':
0535 if (strcmp(optarg, "client") == 0)
0536 opts.mode = TEST_MODE_CLIENT;
0537 else if (strcmp(optarg, "server") == 0)
0538 opts.mode = TEST_MODE_SERVER;
0539 else {
0540 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
0541 return EXIT_FAILURE;
0542 }
0543 break;
0544 case 'p':
0545 opts.peer_cid = parse_cid(optarg);
0546 break;
0547 case 'P':
0548 control_port = optarg;
0549 break;
0550 case 'l':
0551 list_tests(test_cases);
0552 break;
0553 case 's':
0554 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
0555 optarg);
0556 break;
0557 case '?':
0558 default:
0559 usage();
0560 }
0561 }
0562
0563 if (!control_port)
0564 usage();
0565 if (opts.mode == TEST_MODE_UNSET)
0566 usage();
0567 if (opts.peer_cid == VMADDR_CID_ANY)
0568 usage();
0569
0570 if (!control_host) {
0571 if (opts.mode != TEST_MODE_SERVER)
0572 usage();
0573 control_host = "0.0.0.0";
0574 }
0575
0576 control_init(control_host, control_port,
0577 opts.mode == TEST_MODE_SERVER);
0578
0579 run_tests(test_cases, &opts);
0580
0581 control_cleanup();
0582 return EXIT_SUCCESS;
0583 }