Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * vsock_diag_test - vsock_diag.ko test suite
0004  *
0005  * Copyright (C) 2017 Red Hat, Inc.
0006  *
0007  * Author: Stefan Hajnoczi <stefanha@redhat.com>
0008  */
0009 
0010 #include <getopt.h>
0011 #include <stdio.h>
0012 #include <stdlib.h>
0013 #include <string.h>
0014 #include <errno.h>
0015 #include <unistd.h>
0016 #include <sys/stat.h>
0017 #include <sys/types.h>
0018 #include <linux/list.h>
0019 #include <linux/net.h>
0020 #include <linux/netlink.h>
0021 #include <linux/sock_diag.h>
0022 #include <linux/vm_sockets_diag.h>
0023 #include <netinet/tcp.h>
0024 
0025 #include "timeout.h"
0026 #include "control.h"
0027 #include "util.h"
0028 
/*
 * Per-socket status: one entry per vsock_diag_msg received from the kernel,
 * kept on a list_head-linked list built by add_vsock_stat().
 */
struct vsock_stat {
	struct list_head list;		/* link into the caller's socket list */
	struct vsock_diag_msg msg;	/* copy of the netlink diag payload */
};
0034 
/*
 * Return a printable name for a socket type as reported in
 * vsock_diag_msg.vdiag_type.  SOCK_SEQPACKET is listed so that vsock
 * seqpacket sockets are not mislabelled as "INVALID TYPE" in diagnostics.
 */
static const char *sock_type_str(int type)
{
	switch (type) {
	case SOCK_DGRAM:
		return "DGRAM";
	case SOCK_STREAM:
		return "STREAM";
	case SOCK_SEQPACKET:
		return "SEQPACKET";
	default:
		return "INVALID TYPE";
	}
}
0046 
0047 static const char *sock_state_str(int state)
0048 {
0049     switch (state) {
0050     case TCP_CLOSE:
0051         return "UNCONNECTED";
0052     case TCP_SYN_SENT:
0053         return "CONNECTING";
0054     case TCP_ESTABLISHED:
0055         return "CONNECTED";
0056     case TCP_CLOSING:
0057         return "DISCONNECTING";
0058     case TCP_LISTEN:
0059         return "LISTEN";
0060     default:
0061         return "INVALID STATE";
0062     }
0063 }
0064 
/*
 * Describe the vdiag_shutdown mask: 1 is RCV_SHUTDOWN, 2 is SEND_SHUTDOWN,
 * 3 is both.  Every other value (including 0) is printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";

	return names[shutdown];
}
0078 
/*
 * Write a vsock address as "<cid>:<port>" to fp.  The wildcard values
 * VMADDR_CID_ANY and VMADDR_PORT_ANY are shown as "*".  No newline.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
0091 
0092 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
0093 {
0094     print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
0095     fprintf(fp, " ");
0096     print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
0097     fprintf(fp, " %s %s %s %u\n",
0098         sock_type_str(st->msg.vdiag_type),
0099         sock_state_str(st->msg.vdiag_state),
0100         sock_shutdown_str(st->msg.vdiag_shutdown),
0101         st->msg.vdiag_ino);
0102 }
0103 
0104 static void print_vsock_stats(FILE *fp, struct list_head *head)
0105 {
0106     struct vsock_stat *st;
0107 
0108     list_for_each_entry(st, head, list)
0109         print_vsock_stat(fp, st);
0110 }
0111 
0112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
0113 {
0114     struct vsock_stat *st;
0115     struct stat stat;
0116 
0117     if (fstat(fd, &stat) < 0) {
0118         perror("fstat");
0119         exit(EXIT_FAILURE);
0120     }
0121 
0122     list_for_each_entry(st, head, list)
0123         if (st->msg.vdiag_ino == stat.st_ino)
0124             return st;
0125 
0126     fprintf(stderr, "cannot find fd %d\n", fd);
0127     exit(EXIT_FAILURE);
0128 }
0129 
/*
 * Fail the test unless the socket list is empty.  Any sockets found are
 * dumped to stderr first.  Uses EXIT_FAILURE, like the other failure paths
 * in this file.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		exit(EXIT_FAILURE);
	}
}
0138 
0139 static void check_num_sockets(struct list_head *head, int expected)
0140 {
0141     struct list_head *node;
0142     int n = 0;
0143 
0144     list_for_each(node, head)
0145         n++;
0146 
0147     if (n != expected) {
0148         fprintf(stderr, "expected %d sockets, found %d\n",
0149             expected, n);
0150         print_vsock_stats(stderr, head);
0151         exit(EXIT_FAILURE);
0152     }
0153 }
0154 
0155 static void check_socket_state(struct vsock_stat *st, __u8 state)
0156 {
0157     if (st->msg.vdiag_state != state) {
0158         fprintf(stderr, "expected socket state %#x, got %#x\n",
0159             state, st->msg.vdiag_state);
0160         exit(EXIT_FAILURE);
0161     }
0162 }
0163 
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request (NLM_F_REQUEST | NLM_F_DUMP) for
 * AF_VSOCK sockets on the netlink socket fd.  The request is retried if
 * interrupted by a signal (EINTR); any other sendmsg() failure exits.
 */
static void send_req(int fd)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			/* every bit set: do not filter on socket state */
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &msg, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
0206 
/*
 * Receive one datagram from the netlink socket fd into buf (capacity len).
 *
 * Retries on EINTR.  Exits the test on a recvmsg() error, or if the
 * datagram did not fit in buf (MSG_TRUNC in msg_flags): silently keeping
 * only the head of a dump reply would drop sockets and make the socket
 * count checks fail for the wrong reason.
 *
 * Returns the number of bytes received (0 at end of stream).
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t ret;

	do {
		ret = recvmsg(fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		perror("recvmsg");
		exit(EXIT_FAILURE);
	}

	/* recvmsg() reports a datagram larger than the buffer via MSG_TRUNC */
	if (msg.msg_flags & MSG_TRUNC) {
		fprintf(stderr, "truncated netlink response (buffer of %zu bytes)\n",
			len);
		exit(EXIT_FAILURE);
	}

	return ret;
}
0235 
0236 static void add_vsock_stat(struct list_head *sockets,
0237                const struct vsock_diag_msg *resp)
0238 {
0239     struct vsock_stat *st;
0240 
0241     st = malloc(sizeof(*st));
0242     if (!st) {
0243         perror("malloc");
0244         exit(EXIT_FAILURE);
0245     }
0246 
0247     st->msg = *resp;
0248     list_add_tail(&st->list, sockets);
0249 }
0250 
/*
 * Dump all AF_VSOCK sockets through a NETLINK_SOCK_DIAG socket and append
 * one struct vsock_stat per reported socket to the sockets list.
 *
 * Reading stops at NLMSG_DONE (or end of stream).  A netlink error, an
 * unexpected message type, a short payload or a malformed message inside
 * a datagram terminates the test.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* long[] so the cast to struct nlmsghdr below is suitably aligned */
	long buf[8192 / sizeof(long)];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		const struct nlmsghdr *h;
		ssize_t ret;

		ret = recv_resp(fd, buf, sizeof(buf));
		if (ret == 0)
			goto done;
		/* recv_resp() exits on error, so ret > 0 and the cast is safe */
		if ((size_t)ret < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", ret);
			exit(EXIT_FAILURE);
		}

		h = (struct nlmsghdr *)buf;

		while (NLMSG_OK(h, ret)) {
			if (h->nlmsg_type == NLMSG_DONE)
				goto done;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
					fprintf(stderr, "NLMSG_ERROR\n");
				else {
					errno = -err->error;
					perror("NLMSG_ERROR");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}
			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));

			/* NLMSG_NEXT() also subtracts the consumed bytes from ret */
			h = NLMSG_NEXT(h, ret);
		}

		/*
		 * NLMSG_OK() stops early on a malformed or truncated message;
		 * report it instead of silently discarding the remaining bytes
		 * (and any sockets they describe).
		 */
		if (ret != 0) {
			fprintf(stderr,
				"malformed netlink message (%zd trailing bytes)\n",
				ret);
			exit(EXIT_FAILURE);
		}
	}

done:
	close(fd);
}
0318 
0319 static void free_sock_stat(struct list_head *sockets)
0320 {
0321     struct vsock_stat *st;
0322     struct vsock_stat *next;
0323 
0324     list_for_each_entry_safe(st, next, sockets, list)
0325         free(st);
0326 }
0327 
0328 static void test_no_sockets(const struct test_opts *opts)
0329 {
0330     LIST_HEAD(sockets);
0331 
0332     read_vsock_stat(&sockets);
0333 
0334     check_no_sockets(&sockets);
0335 }
0336 
0337 static void test_listen_socket_server(const struct test_opts *opts)
0338 {
0339     union {
0340         struct sockaddr sa;
0341         struct sockaddr_vm svm;
0342     } addr = {
0343         .svm = {
0344             .svm_family = AF_VSOCK,
0345             .svm_port = 1234,
0346             .svm_cid = VMADDR_CID_ANY,
0347         },
0348     };
0349     LIST_HEAD(sockets);
0350     struct vsock_stat *st;
0351     int fd;
0352 
0353     fd = socket(AF_VSOCK, SOCK_STREAM, 0);
0354 
0355     if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
0356         perror("bind");
0357         exit(EXIT_FAILURE);
0358     }
0359 
0360     if (listen(fd, 1) < 0) {
0361         perror("listen");
0362         exit(EXIT_FAILURE);
0363     }
0364 
0365     read_vsock_stat(&sockets);
0366 
0367     check_num_sockets(&sockets, 1);
0368     st = find_vsock_stat(&sockets, fd);
0369     check_socket_state(st, TCP_LISTEN);
0370 
0371     close(fd);
0372     free_sock_stat(&sockets);
0373 }
0374 
0375 static void test_connect_client(const struct test_opts *opts)
0376 {
0377     int fd;
0378     LIST_HEAD(sockets);
0379     struct vsock_stat *st;
0380 
0381     fd = vsock_stream_connect(opts->peer_cid, 1234);
0382     if (fd < 0) {
0383         perror("connect");
0384         exit(EXIT_FAILURE);
0385     }
0386 
0387     read_vsock_stat(&sockets);
0388 
0389     check_num_sockets(&sockets, 1);
0390     st = find_vsock_stat(&sockets, fd);
0391     check_socket_state(st, TCP_ESTABLISHED);
0392 
0393     control_expectln("DONE");
0394     control_writeln("DONE");
0395 
0396     close(fd);
0397     free_sock_stat(&sockets);
0398 }
0399 
0400 static void test_connect_server(const struct test_opts *opts)
0401 {
0402     struct vsock_stat *st;
0403     LIST_HEAD(sockets);
0404     int client_fd;
0405 
0406     client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
0407     if (client_fd < 0) {
0408         perror("accept");
0409         exit(EXIT_FAILURE);
0410     }
0411 
0412     read_vsock_stat(&sockets);
0413 
0414     check_num_sockets(&sockets, 1);
0415     st = find_vsock_stat(&sockets, client_fd);
0416     check_socket_state(st, TCP_ESTABLISHED);
0417 
0418     control_writeln("DONE");
0419     control_expectln("DONE");
0420 
0421     close(client_fd);
0422     free_sock_stat(&sockets);
0423 }
0424 
/*
 * Test table handed to run_tests(), list_tests() and skip_test().  Each
 * entry names a test and gives its server-side and/or client-side function.
 * NOTE(review): a side without a function presumably does nothing for that
 * test; confirm in util.c.
 *
 * The trailing empty entry terminates the table; main() passes
 * ARRAY_SIZE(test_cases) - 1 to skip_test() to exclude it.  --skip test IDs
 * are presumably positions in this array, so keep the order stable.
 */
static struct test_case test_cases[] = {
	{
		.name = "No sockets",
		.run_server = test_no_sockets,
	},
	{
		.name = "Listen socket",
		.run_server = test_listen_socket_server,
	},
	{
		.name = "Connect",
		.run_client = test_connect_client,
		.run_server = test_connect_server,
	},
	{},
};
0441 
/* No short options are accepted; every option is long-only (see longopts). */
static const char optstring[] = "";
/*
 * Long options understood by main().  .val is the code getopt_long()
 * returns and main() switches on.  "help" maps to '?', the same code
 * getopt_long() returns for an unrecognised option, so both end in usage().
 */
static const struct option longopts[] = {
	{
		.name = "control-host",
		.has_arg = required_argument,
		.val = 'H',
	},
	{
		.name = "control-port",
		.has_arg = required_argument,
		.val = 'P',
	},
	{
		.name = "mode",
		.has_arg = required_argument,
		.val = 'm',
	},
	{
		.name = "peer-cid",
		.has_arg = required_argument,
		.val = 'p',
	},
	{
		.name = "list",
		.has_arg = no_argument,
		.val = 'l',
	},
	{
		.name = "skip",
		.has_arg = required_argument,
		.val = 's',
	},
	{
		.name = "help",
		.has_arg = no_argument,
		.val = '?',
	},
	{},	/* all-zero terminator required by getopt_long() */
};
0481 
/*
 * Print command-line help to stderr and exit with EXIT_FAILURE.
 * The text contains no conversion specifiers, so plain fputs() suffices.
 */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests.  Must be launched in both\n"
		"guest and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
0512 
0513 int main(int argc, char **argv)
0514 {
0515     const char *control_host = NULL;
0516     const char *control_port = NULL;
0517     struct test_opts opts = {
0518         .mode = TEST_MODE_UNSET,
0519         .peer_cid = VMADDR_CID_ANY,
0520     };
0521 
0522     init_signals();
0523 
0524     for (;;) {
0525         int opt = getopt_long(argc, argv, optstring, longopts, NULL);
0526 
0527         if (opt == -1)
0528             break;
0529 
0530         switch (opt) {
0531         case 'H':
0532             control_host = optarg;
0533             break;
0534         case 'm':
0535             if (strcmp(optarg, "client") == 0)
0536                 opts.mode = TEST_MODE_CLIENT;
0537             else if (strcmp(optarg, "server") == 0)
0538                 opts.mode = TEST_MODE_SERVER;
0539             else {
0540                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
0541                 return EXIT_FAILURE;
0542             }
0543             break;
0544         case 'p':
0545             opts.peer_cid = parse_cid(optarg);
0546             break;
0547         case 'P':
0548             control_port = optarg;
0549             break;
0550         case 'l':
0551             list_tests(test_cases);
0552             break;
0553         case 's':
0554             skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
0555                   optarg);
0556             break;
0557         case '?':
0558         default:
0559             usage();
0560         }
0561     }
0562 
0563     if (!control_port)
0564         usage();
0565     if (opts.mode == TEST_MODE_UNSET)
0566         usage();
0567     if (opts.peer_cid == VMADDR_CID_ANY)
0568         usage();
0569 
0570     if (!control_host) {
0571         if (opts.mode != TEST_MODE_SERVER)
0572             usage();
0573         control_host = "0.0.0.0";
0574     }
0575 
0576     control_init(control_host, control_port,
0577              opts.mode == TEST_MODE_SERVER);
0578 
0579     run_tests(test_cases, &opts);
0580 
0581     control_cleanup();
0582     return EXIT_SUCCESS;
0583 }