Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * vsock test utilities
0004  *
0005  * Copyright (C) 2017 Red Hat, Inc.
0006  *
0007  * Author: Stefan Hajnoczi <stefanha@redhat.com>
0008  */
0009 
0010 #include <errno.h>
0011 #include <stdio.h>
0012 #include <stdint.h>
0013 #include <stdlib.h>
0014 #include <signal.h>
0015 #include <unistd.h>
0016 #include <assert.h>
0017 #include <sys/epoll.h>
0018 
0019 #include "timeout.h"
0020 #include "control.h"
0021 #include "util.h"
0022 
0023 /* Install signal handlers */
0024 void init_signals(void)
0025 {
0026     struct sigaction act = {
0027         .sa_handler = sigalrm,
0028     };
0029 
0030     sigaction(SIGALRM, &act, NULL);
0031     signal(SIGPIPE, SIG_IGN);
0032 }
0033 
/* Parse a CID in string representation.
 *
 * @str: decimal CID, e.g. "2"
 *
 * Returns the parsed CID.  Exits the process with an error message if @str
 * is empty, has trailing characters, overflows strtoul(), or does not fit
 * in an unsigned int.
 */
unsigned int parse_cid(const char *str)
{
    char *endptr = NULL;
    unsigned long n;

    errno = 0;
    n = strtoul(str, &endptr, 10);
    /* endptr == str means no digits were consumed (e.g. ""), which
     * strtoul() would otherwise report as a successful parse of 0.
     * The cast comparison rejects values that would be truncated when
     * returned as unsigned int (possible where long is wider than int).
     */
    if (errno || endptr == str || *endptr != '\0' ||
        (unsigned int)n != n) {
        fprintf(stderr, "malformed CID \"%s\"\n", str);
        exit(EXIT_FAILURE);
    }
    return (unsigned int)n;
}
0048 
0049 /* Wait for the remote to close the connection */
0050 void vsock_wait_remote_close(int fd)
0051 {
0052     struct epoll_event ev;
0053     int epollfd, nfds;
0054 
0055     epollfd = epoll_create1(0);
0056     if (epollfd == -1) {
0057         perror("epoll_create1");
0058         exit(EXIT_FAILURE);
0059     }
0060 
0061     ev.events = EPOLLRDHUP | EPOLLHUP;
0062     ev.data.fd = fd;
0063     if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
0064         perror("epoll_ctl");
0065         exit(EXIT_FAILURE);
0066     }
0067 
0068     nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
0069     if (nfds == -1) {
0070         perror("epoll_wait");
0071         exit(EXIT_FAILURE);
0072     }
0073 
0074     if (nfds == 0) {
0075         fprintf(stderr, "epoll_wait timed out\n");
0076         exit(EXIT_FAILURE);
0077     }
0078 
0079     assert(nfds == 1);
0080     assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
0081     assert(ev.data.fd == fd);
0082 
0083     close(epollfd);
0084 }
0085 
0086 /* Connect to <cid, port> and return the file descriptor. */
0087 static int vsock_connect(unsigned int cid, unsigned int port, int type)
0088 {
0089     union {
0090         struct sockaddr sa;
0091         struct sockaddr_vm svm;
0092     } addr = {
0093         .svm = {
0094             .svm_family = AF_VSOCK,
0095             .svm_port = port,
0096             .svm_cid = cid,
0097         },
0098     };
0099     int ret;
0100     int fd;
0101 
0102     control_expectln("LISTENING");
0103 
0104     fd = socket(AF_VSOCK, type, 0);
0105 
0106     timeout_begin(TIMEOUT);
0107     do {
0108         ret = connect(fd, &addr.sa, sizeof(addr.svm));
0109         timeout_check("connect");
0110     } while (ret < 0 && errno == EINTR);
0111     timeout_end();
0112 
0113     if (ret < 0) {
0114         int old_errno = errno;
0115 
0116         close(fd);
0117         fd = -1;
0118         errno = old_errno;
0119     }
0120     return fd;
0121 }
0122 
/* Connect a SOCK_STREAM vsock socket to <cid, port>.
 *
 * Returns the connected fd, or -1 with errno set on failure.
 */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
    const int sock_type = SOCK_STREAM;

    return vsock_connect(cid, port, sock_type);
}
0127 
/* Connect a SOCK_SEQPACKET vsock socket to <cid, port>.
 *
 * Returns the connected fd, or -1 with errno set on failure.
 */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
    const int sock_type = SOCK_SEQPACKET;

    return vsock_connect(cid, port, sock_type);
}
0132 
0133 /* Listen on <cid, port> and return the first incoming connection.  The remote
0134  * address is stored to clientaddrp.  clientaddrp may be NULL.
0135  */
0136 static int vsock_accept(unsigned int cid, unsigned int port,
0137             struct sockaddr_vm *clientaddrp, int type)
0138 {
0139     union {
0140         struct sockaddr sa;
0141         struct sockaddr_vm svm;
0142     } addr = {
0143         .svm = {
0144             .svm_family = AF_VSOCK,
0145             .svm_port = port,
0146             .svm_cid = cid,
0147         },
0148     };
0149     union {
0150         struct sockaddr sa;
0151         struct sockaddr_vm svm;
0152     } clientaddr;
0153     socklen_t clientaddr_len = sizeof(clientaddr.svm);
0154     int fd;
0155     int client_fd;
0156     int old_errno;
0157 
0158     fd = socket(AF_VSOCK, type, 0);
0159 
0160     if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
0161         perror("bind");
0162         exit(EXIT_FAILURE);
0163     }
0164 
0165     if (listen(fd, 1) < 0) {
0166         perror("listen");
0167         exit(EXIT_FAILURE);
0168     }
0169 
0170     control_writeln("LISTENING");
0171 
0172     timeout_begin(TIMEOUT);
0173     do {
0174         client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
0175         timeout_check("accept");
0176     } while (client_fd < 0 && errno == EINTR);
0177     timeout_end();
0178 
0179     old_errno = errno;
0180     close(fd);
0181     errno = old_errno;
0182 
0183     if (client_fd < 0)
0184         return client_fd;
0185 
0186     if (clientaddr_len != sizeof(clientaddr.svm)) {
0187         fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
0188             (size_t)clientaddr_len);
0189         exit(EXIT_FAILURE);
0190     }
0191     if (clientaddr.sa.sa_family != AF_VSOCK) {
0192         fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
0193             clientaddr.sa.sa_family);
0194         exit(EXIT_FAILURE);
0195     }
0196 
0197     if (clientaddrp)
0198         *clientaddrp = clientaddr.svm;
0199     return client_fd;
0200 }
0201 
/* Accept one SOCK_STREAM vsock connection on <cid, port>.
 *
 * The remote address is stored to @clientaddrp unless it is NULL.
 * Returns the accepted fd, or -1 with errno set on failure.
 */
int vsock_stream_accept(unsigned int cid, unsigned int port,
            struct sockaddr_vm *clientaddrp)
{
    const int sock_type = SOCK_STREAM;

    return vsock_accept(cid, port, clientaddrp, sock_type);
}
0207 
/* Accept one SOCK_SEQPACKET vsock connection on <cid, port>.
 *
 * The remote address is stored to @clientaddrp unless it is NULL.
 * Returns the accepted fd, or -1 with errno set on failure.
 */
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
               struct sockaddr_vm *clientaddrp)
{
    const int sock_type = SOCK_SEQPACKET;

    return vsock_accept(cid, port, clientaddrp, sock_type);
}
0213 
0214 /* Transmit one byte and check the return value.
0215  *
0216  * expected_ret:
0217  *  <0 Negative errno (for testing errors)
0218  *   0 End-of-file
0219  *   1 Success
0220  */
0221 void send_byte(int fd, int expected_ret, int flags)
0222 {
0223     const uint8_t byte = 'A';
0224     ssize_t nwritten;
0225 
0226     timeout_begin(TIMEOUT);
0227     do {
0228         nwritten = send(fd, &byte, sizeof(byte), flags);
0229         timeout_check("write");
0230     } while (nwritten < 0 && errno == EINTR);
0231     timeout_end();
0232 
0233     if (expected_ret < 0) {
0234         if (nwritten != -1) {
0235             fprintf(stderr, "bogus send(2) return value %zd\n",
0236                 nwritten);
0237             exit(EXIT_FAILURE);
0238         }
0239         if (errno != -expected_ret) {
0240             perror("write");
0241             exit(EXIT_FAILURE);
0242         }
0243         return;
0244     }
0245 
0246     if (nwritten < 0) {
0247         perror("write");
0248         exit(EXIT_FAILURE);
0249     }
0250     if (nwritten == 0) {
0251         if (expected_ret == 0)
0252             return;
0253 
0254         fprintf(stderr, "unexpected EOF while sending byte\n");
0255         exit(EXIT_FAILURE);
0256     }
0257     if (nwritten != sizeof(byte)) {
0258         fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
0259         exit(EXIT_FAILURE);
0260     }
0261 }
0262 
0263 /* Receive one byte and check the return value.
0264  *
0265  * expected_ret:
0266  *  <0 Negative errno (for testing errors)
0267  *   0 End-of-file
0268  *   1 Success
0269  */
0270 void recv_byte(int fd, int expected_ret, int flags)
0271 {
0272     uint8_t byte;
0273     ssize_t nread;
0274 
0275     timeout_begin(TIMEOUT);
0276     do {
0277         nread = recv(fd, &byte, sizeof(byte), flags);
0278         timeout_check("read");
0279     } while (nread < 0 && errno == EINTR);
0280     timeout_end();
0281 
0282     if (expected_ret < 0) {
0283         if (nread != -1) {
0284             fprintf(stderr, "bogus recv(2) return value %zd\n",
0285                 nread);
0286             exit(EXIT_FAILURE);
0287         }
0288         if (errno != -expected_ret) {
0289             perror("read");
0290             exit(EXIT_FAILURE);
0291         }
0292         return;
0293     }
0294 
0295     if (nread < 0) {
0296         perror("read");
0297         exit(EXIT_FAILURE);
0298     }
0299     if (nread == 0) {
0300         if (expected_ret == 0)
0301             return;
0302 
0303         fprintf(stderr, "unexpected EOF while receiving byte\n");
0304         exit(EXIT_FAILURE);
0305     }
0306     if (nread != sizeof(byte)) {
0307         fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
0308         exit(EXIT_FAILURE);
0309     }
0310     if (byte != 'A') {
0311         fprintf(stderr, "unexpected byte read %c\n", byte);
0312         exit(EXIT_FAILURE);
0313     }
0314 }
0315 
0316 /* Run test cases.  The program terminates if a failure occurs. */
0317 void run_tests(const struct test_case *test_cases,
0318            const struct test_opts *opts)
0319 {
0320     int i;
0321 
0322     for (i = 0; test_cases[i].name; i++) {
0323         void (*run)(const struct test_opts *opts);
0324         char *line;
0325 
0326         printf("%d - %s...", i, test_cases[i].name);
0327         fflush(stdout);
0328 
0329         /* Full barrier before executing the next test.  This
0330          * ensures that client and server are executing the
0331          * same test case.  In particular, it means whoever is
0332          * faster will not see the peer still executing the
0333          * last test.  This is important because port numbers
0334          * can be used by multiple test cases.
0335          */
0336         if (test_cases[i].skip)
0337             control_writeln("SKIP");
0338         else
0339             control_writeln("NEXT");
0340 
0341         line = control_readln();
0342         if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
0343 
0344             printf("skipped\n");
0345 
0346             free(line);
0347             continue;
0348         }
0349 
0350         control_cmpln(line, "NEXT", true);
0351         free(line);
0352 
0353         if (opts->mode == TEST_MODE_CLIENT)
0354             run = test_cases[i].run_client;
0355         else
0356             run = test_cases[i].run_server;
0357 
0358         if (run)
0359             run(opts);
0360 
0361         printf("ok\n");
0362     }
0363 }
0364 
0365 void list_tests(const struct test_case *test_cases)
0366 {
0367     int i;
0368 
0369     printf("ID\tTest name\n");
0370 
0371     for (i = 0; test_cases[i].name; i++)
0372         printf("%d\t%s\n", i, test_cases[i].name);
0373 
0374     exit(EXIT_FAILURE);
0375 }
0376 
0377 void skip_test(struct test_case *test_cases, size_t test_cases_len,
0378            const char *test_id_str)
0379 {
0380     unsigned long test_id;
0381     char *endptr = NULL;
0382 
0383     errno = 0;
0384     test_id = strtoul(test_id_str, &endptr, 10);
0385     if (errno || *endptr != '\0') {
0386         fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
0387         exit(EXIT_FAILURE);
0388     }
0389 
0390     if (test_id >= test_cases_len) {
0391         fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
0392             test_id, test_cases_len - 1);
0393         exit(EXIT_FAILURE);
0394     }
0395 
0396     test_cases[test_id].skip = true;
0397 }