Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 
0003 #define _GNU_SOURCE
0004 
0005 #include <assert.h>
0006 #include <errno.h>
0007 #include <fcntl.h>
0008 #include <limits.h>
0009 #include <string.h>
0010 #include <stdarg.h>
0011 #include <stdbool.h>
0012 #include <stdint.h>
0013 #include <inttypes.h>
0014 #include <stdio.h>
0015 #include <stdlib.h>
0016 #include <strings.h>
0017 #include <unistd.h>
0018 #include <time.h>
0019 
0020 #include <sys/ioctl.h>
0021 #include <sys/socket.h>
0022 #include <sys/types.h>
0023 #include <sys/wait.h>
0024 
0025 #include <netdb.h>
0026 #include <netinet/in.h>
0027 
0028 #include <linux/tcp.h>
0029 #include <linux/sockios.h>
0030 
/*
 * Fallback protocol numbers, used only when the system headers do not
 * already provide them.
 */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/* Address family used by both endpoints; "-6" switches it to AF_INET6. */
static int pf = AF_INET;
/* Protocol of the connecting (client) socket; selected with "-t". */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol of the listening (server) socket; selected with "-r". */
static int proto_rx = IPPROTO_MPTCP;
0041 
/* Report msg together with the current errno text, then exit with status 1. */
static void die_perror(const char *msg)
{
    perror(msg);
    exit(EXIT_FAILURE);
}
0047 
/* Print the command-line synopsis to stderr and exit with status r. */
static void die_usage(int r)
{
    static const char usage[] =
        "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

    fputs(usage, stderr);
    exit(r);
}
0053 
/*
 * printf-style fatal error: write the formatted message to stderr,
 * terminate it with a newline and exit with status 1.
 */
static void xerror(const char *fmt, ...)
{
    va_list args;

    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);

    putc('\n', stderr);
    exit(1);
}
0064 
/*
 * Turn a getaddrinfo() error code into text. For EAI_SYSTEM the actual
 * cause lives in errno, so its strerror() text is returned instead.
 */
static const char *getxinfo_strerr(int err)
{
    return (err == EAI_SYSTEM) ? strerror(errno) : gai_strerror(err);
}
0072 
/*
 * getaddrinfo() wrapper that only returns on success. On failure it
 * prints the node:service pair and the error reason to stderr and exits
 * with status 1. NULL node/service are printed as empty strings.
 */
static void xgetaddrinfo(const char *node, const char *service,
             const struct addrinfo *hints,
             struct addrinfo **res)
{
    const char *why;
    int rc;

    rc = getaddrinfo(node, service, hints, res);
    if (rc == 0)
        return;

    why = getxinfo_strerr(rc);
    fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
        node ? node : "", service ? service : "", why);
    exit(1);
}
0087 
/*
 * Create a listening socket on listenaddr:port.
 *
 * Addresses are resolved numerically for family pf. Each getaddrinfo()
 * result is tried in order, and the socket is created with protocol
 * proto_rx (not the IPPROTO_TCP used in the lookup hints). The first
 * address that binds is used.
 *
 * Returns the listening fd (backlog 20). Exits if no address could be
 * bound or listen() fails.
 */
static int sock_listen_mptcp(const char * const listenaddr,
                 const char * const port)
{
    struct addrinfo hints = {
        .ai_family = pf,
        .ai_protocol = IPPROTO_TCP,
        .ai_socktype = SOCK_STREAM,
        .ai_flags = AI_PASSIVE | AI_NUMERICHOST
    };
    struct addrinfo *a, *addr;
    int sock = -1;
    int one = 1;

    xgetaddrinfo(listenaddr, port, &hints, &addr);

    for (a = addr; a; a = a->ai_next) {
        sock = socket(a->ai_family, a->ai_socktype, proto_rx);
        if (sock < 0)
            continue;

        /* best effort: a failure is reported but does not abort the attempt */
        if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
                     sizeof(one)))
            perror("setsockopt");

        if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
            break; /* success */

        perror("bind");
        close(sock);
        sock = -1;
    }

    freeaddrinfo(addr);

    if (sock < 0)
        xerror("could not create listen socket");

    if (listen(sock, 20))
        die_perror("listen");

    return sock;
}
0133 
/*
 * Connect to remoteaddr:port with a socket of protocol proto, using
 * address family pf.
 *
 * Each getaddrinfo() result is tried in order, matching the behavior of
 * sock_listen_mptcp(). A failed connect() is reported with perror() and
 * its socket is closed before the next address is tried.
 *
 * Returns the connected fd. Exits if no address could be connected.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
                  const char * const port, int proto)
{
    struct addrinfo hints = {
        .ai_family = pf,
        .ai_protocol = IPPROTO_TCP,
        .ai_socktype = SOCK_STREAM,
    };
    struct addrinfo *a, *addr;
    int sock = -1;

    xgetaddrinfo(remoteaddr, port, &hints, &addr);

    for (a = addr; a; a = a->ai_next) {
        sock = socket(a->ai_family, a->ai_socktype, proto);
        if (sock < 0)
            continue;

        if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
            break; /* success */

        perror("connect");
        close(sock);
        sock = -1;
    }

    /* release the address list on every path, including the error path */
    freeaddrinfo(addr);

    if (sock < 0)
        xerror("could not create connect socket");

    return sock;
}
0164 
/*
 * Map a "-t"/"-r" argument to an IP protocol number.
 * Accepts "tcp" or "mptcp" (case-insensitive). Anything else prints
 * usage and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
    if (!strcasecmp(s, "mptcp"))
        return IPPROTO_MPTCP;
    if (!strcasecmp(s, "tcp"))
        return IPPROTO_TCP;

    die_usage(1);
    return 0; /* not reached */
}
0175 
0176 static void parse_opts(int argc, char **argv)
0177 {
0178     int c;
0179 
0180     while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
0181         switch (c) {
0182         case 'h':
0183             die_usage(0);
0184             break;
0185         case '6':
0186             pf = AF_INET6;
0187             break;
0188         case 't':
0189             proto_tx = protostr_to_num(optarg);
0190             break;
0191         case 'r':
0192             proto_rx = protostr_to_num(optarg);
0193             break;
0194         default:
0195             die_usage(1);
0196             break;
0197         }
0198     }
0199 }
0200 
/*
 * Wait until the peer has acknowledged all data written to fd.
 *
 * The send queue is polled with TIOCOUTQ once per millisecond, for at
 * most timeout iterations (roughly timeout milliseconds). The function
 * returns as soon as the queue is empty.
 *
 * @fd:      connected stream socket
 * @timeout: maximum number of 1 ms polling rounds
 * @total:   bytes written so far; a larger queue size is a fatal error
 *
 * Exits if the queue size is inconsistent or data is still queued when
 * the timeout expires.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
    int i;

    for (i = 0; i < timeout; i++) {
        int nsd = 0, ret, queued = -1;
        struct timespec req;

        ret = ioctl(fd, TIOCOUTQ, &queued);
        if (ret < 0)
            die_perror("TIOCOUTQ");

        ret = ioctl(fd, SIOCOUTQNSD, &nsd);
        if (ret < 0)
            die_perror("SIOCOUTQNSD");

        /* check sign first so a negative count is never cast to a huge size_t */
        if (queued < 0 || (size_t)queued > total)
            xerror("TIOCOUTQ %d, but only %zu expected", queued, total);
        /* unsent bytes (SIOCOUTQNSD) are a subset of all queued bytes */
        assert(nsd <= queued);

        if (queued == 0)
            return;

        /* wait for peer to ack rx of all data */
        req.tv_sec = 0;
        req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
        nanosleep(&req, NULL);
    }

    xerror("still tx data queued after %d ms", timeout);
}
0233 
/*
 * Client side of the test exchange. Sequencing messages travel over the
 * unix socket unixfd; payload travels over the connected socket fd.
 *
 *  1. On "xmit": announce len (a size_t in [128, 4094]) and send len
 *     random uppercase letters on fd.
 *  2. On "huge": announce total (a size_t in [1 MiB, 17 MiB)), wait
 *     until the first payload is acked, then stream total bytes on fd.
 *  3. On "shut": wait until all of it is acked, send one more byte,
 *     close fd and reply "closed".
 *
 * Closes both fd and unixfd. Any unexpected result aborts or exits.
 */
static void connect_one_server(int fd, int unixfd)
{
    size_t len, i, total, sent;
    char buf[4096], buf2[4096];
    ssize_t ret;

    len = rand() % (sizeof(buf) - 1);

    if (len < 128)
        len = 128;

    for (i = 0; i < len; i++) {
        buf[i] = rand() % 26;
        buf[i] += 'A';
    }

    /* written just past len (len <= 4094, so in bounds); never transmitted */
    buf[i] = '\n';

    /* un-block server */
    ret = read(unixfd, buf2, 4);
    assert(ret == 4);

    assert(strncmp(buf2, "xmit", 4) == 0);

    ret = write(unixfd, &len, sizeof(len));
    assert(ret == (ssize_t)sizeof(len));

    ret = write(fd, buf, len);
    if (ret < 0)
        die_perror("write");

    if (ret != (ssize_t)len)
        xerror("short write");

    /* like every other sequencing read, insist on a complete 4-byte message */
    ret = read(unixfd, buf2, 4);
    assert(ret == 4);
    assert(strncmp(buf2, "huge", 4) == 0);

    total = rand() % (16 * 1024 * 1024);
    total += (1 * 1024 * 1024);
    sent = total;

    ret = write(unixfd, &total, sizeof(total));
    assert(ret == (ssize_t)sizeof(total));

    wait_for_ack(fd, 5000, len);

    while (total > 0) {
        if (total > sizeof(buf))
            len = sizeof(buf);
        else
            len = total;

        ret = write(fd, buf, len);
        if (ret < 0)
            die_perror("write");
        total -= ret;

        /* we don't have to care about buf content, only
         * number of total bytes sent
         */
    }

    ret = read(unixfd, buf2, 4);
    assert(ret == 4);
    assert(strncmp(buf2, "shut", 4) == 0);

    wait_for_ack(fd, 5000, sent);

    ret = write(fd, buf, 1);
    assert(ret == 1);
    close(fd);
    ret = write(unixfd, "closed", 6);
    assert(ret == 6);

    close(unixfd);
}
0310 
/*
 * Find the IPPROTO_TCP / TCP_CM_INQ control message in msgh and copy its
 * payload into *inqv. The server enables this cmsg with
 * setsockopt(TCP_INQ). Exits if the control message is not present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
    struct cmsghdr *c = CMSG_FIRSTHDR(msgh);

    while (c != NULL) {
        if (c->cmsg_level == IPPROTO_TCP &&
            c->cmsg_type == TCP_CM_INQ) {
            memcpy(inqv, CMSG_DATA(c), sizeof(*inqv));
            return;
        }
        c = CMSG_NXTHDR(msgh, c);
    }

    xerror("could not find TCP_CM_INQ cmsg type");
}
0324 
/*
 * Server side of the test exchange; counterpart of connect_one_server().
 *
 * Data is requested via unixfd. After each recvmsg() on fd, the
 * TCP_CM_INQ value (bytes still queued) is checked against what is known
 * to be outstanding:
 *
 *  1. "xmit": wait until FIONREAD reports the whole small payload. Read
 *     one byte (inq must be len - 1), then the remainder (inq must be 0).
 *  2. "huge": drain the large transfer; inq must never exceed the bytes
 *     still outstanding.
 *  3. "shut": after the peer reports "closed", read the final byte
 *     (inq expected to be 1 due to the received FIN), then EOF (inq still
 *     1).
 *
 * Closes fd. Any mismatch aborts or exits.
 */
static void process_one_client(int fd, int unixfd)
{
    unsigned int tcp_inq;
    size_t expect_len;
    char msg_buf[4096];
    char buf[4096];
    char tmp[16];
    struct iovec iov = {
        .iov_base = buf,
        .iov_len = 1,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = msg_buf,
        .msg_controllen = sizeof(msg_buf),
    };
    ssize_t ret, tot;

    ret = write(unixfd, "xmit", 4);
    assert(ret == 4);

    ret = read(unixfd, &expect_len, sizeof(expect_len));
    assert(ret == (ssize_t)sizeof(expect_len));

    if (expect_len > sizeof(buf))
        xerror("expect len %zu exceeds buffer size", expect_len);

    /* poll until the complete small payload sits in the receive queue */
    for (;;) {
        struct timespec req;
        unsigned int queued;

        ret = ioctl(fd, FIONREAD, &queued);
        if (ret < 0)
            die_perror("FIONREAD");
        if (queued > expect_len)
            xerror("FIONREAD returned %u, but only %zu expected",
                   queued, expect_len);
        if (queued == expect_len)
            break;

        req.tv_sec = 0;
        req.tv_nsec = 1000 * 1000ul; /* 1ms */
        nanosleep(&req, NULL);
    }

    /*
     * recvmsg() overwrites msg_controllen with the length of the ancillary
     * data it actually returned, so it is restored to the full buffer size
     * before every call below.
     */

    /* read one byte, expect cmsg to return expected - 1 */
    iov.iov_len = 1;
    msg.msg_controllen = sizeof(msg_buf);
    ret = recvmsg(fd, &msg, 0);
    if (ret < 0)
        die_perror("recvmsg");

    if (msg.msg_controllen == 0)
        xerror("msg_controllen is 0");

    get_tcp_inq(&msg, &tcp_inq);

    assert((size_t)tcp_inq == (expect_len - 1));

    iov.iov_len = sizeof(buf);
    msg.msg_controllen = sizeof(msg_buf);
    ret = recvmsg(fd, &msg, 0);
    if (ret < 0)
        die_perror("recvmsg");

    /* should have gotten exact remainder of all pending data */
    assert(ret == (ssize_t)tcp_inq);

    /* should be 0, all drained */
    get_tcp_inq(&msg, &tcp_inq);
    assert(tcp_inq == 0);

    /* request a large swath of data. */
    ret = write(unixfd, "huge", 4);
    assert(ret == 4);

    ret = read(unixfd, &expect_len, sizeof(expect_len));
    assert(ret == (ssize_t)sizeof(expect_len));

    /* peer should send us a few mb of data */
    if (expect_len <= sizeof(buf))
        xerror("expect len %zu too small", expect_len);

    tot = 0;
    do {
        size_t remaining;

        iov.iov_len = sizeof(buf);
        msg.msg_controllen = sizeof(msg_buf);
        ret = recvmsg(fd, &msg, 0);
        if (ret < 0)
            die_perror("recvmsg");
        /* without this, an early EOF would spin here until the alarm fires */
        if (ret == 0)
            xerror("unexpected EOF after %zd of %zu bytes",
                   tot, expect_len);

        tot += ret;
        /* guard the unsigned subtraction below against wrap-around */
        if ((size_t)tot > expect_len)
            xerror("received %zd bytes, but only %zu expected",
                   tot, expect_len);
        remaining = expect_len - (size_t)tot;

        get_tcp_inq(&msg, &tcp_inq);

        if (tcp_inq > remaining)
            xerror("inq %u, remaining %zu total_len %zu",
                   tcp_inq, remaining, expect_len);
    } while ((size_t)tot < expect_len);

    ret = write(unixfd, "shut", 4);
    assert(ret == 4);

    /* wait for hangup. Should have received one more byte of data. */
    ret = read(unixfd, tmp, sizeof(tmp));
    assert(ret == 6);
    assert(strncmp(tmp, "closed", 6) == 0);

    /* presumably gives the last byte and the peer's FIN time to arrive */
    sleep(1);

    iov.iov_len = 1;
    msg.msg_controllen = sizeof(msg_buf);
    ret = recvmsg(fd, &msg, 0);
    if (ret < 0)
        die_perror("recvmsg");
    assert(ret == 1);

    get_tcp_inq(&msg, &tcp_inq);

    /* tcp_inq should be 1 due to received fin. */
    assert(tcp_inq == 1);

    iov.iov_len = 1;
    msg.msg_controllen = sizeof(msg_buf);
    ret = recvmsg(fd, &msg, 0);
    if (ret < 0)
        die_perror("recvmsg");

    /* expect EOF */
    assert(ret == 0);
    get_tcp_inq(&msg, &tcp_inq);
    assert(tcp_inq == 1);

    close(fd);
}
0457 
/* accept() one connection on listening socket s (peer address discarded); exits on error. */
static int xaccept(int s)
{
    int conn = accept(s, NULL, NULL);

    if (conn < 0)
        die_perror("accept");
    return conn;
}
0467 
0468 static int server(int unixfd)
0469 {
0470     int fd = -1, r, on = 1;
0471 
0472     switch (pf) {
0473     case AF_INET:
0474         fd = sock_listen_mptcp("127.0.0.1", "15432");
0475         break;
0476     case AF_INET6:
0477         fd = sock_listen_mptcp("::1", "15432");
0478         break;
0479     default:
0480         xerror("Unknown pf %d\n", pf);
0481         break;
0482     }
0483 
0484     r = write(unixfd, "conn", 4);
0485     assert(r == 4);
0486 
0487     alarm(15);
0488     r = xaccept(fd);
0489 
0490     if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
0491         die_perror("setsockopt");
0492 
0493     process_one_client(r, unixfd);
0494 
0495     return 0;
0496 }
0497 
0498 static int client(int unixfd)
0499 {
0500     int fd = -1;
0501 
0502     alarm(15);
0503 
0504     switch (pf) {
0505     case AF_INET:
0506         fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
0507         break;
0508     case AF_INET6:
0509         fd = sock_connect_mptcp("::1", "15432", proto_tx);
0510         break;
0511     default:
0512         xerror("Unknown pf %d\n", pf);
0513     }
0514 
0515     connect_one_server(fd, unixfd);
0516 
0517     return 0;
0518 }
0519 
/*
 * Seed rand() for the calling process.
 *
 * The seed is read from /dev/urandom when a full value can be read.
 * Otherwise it falls back to a value derived from the current time and
 * the pid, so the seed is never taken from uninitialised memory.
 */
static void init_rng(void)
{
    unsigned int seed = (unsigned int)time(NULL) ^
                ((unsigned int)getpid() << 16);
    int fd = open("/dev/urandom", O_RDONLY);

    /* 0 is a valid descriptor, so test >= 0 rather than > 0 */
    if (fd >= 0) {
        unsigned int rnd;
        ssize_t ret = read(fd, &rnd, sizeof(rnd));

        /* only trust rnd if it was completely filled */
        if (ret == (ssize_t)sizeof(rnd))
            seed = rnd;
        close(fd);
    }

    srand(seed);
}
0535 
/*
 * fork() wrapper. Exits on failure. In the child, rand() is reseeded via
 * init_rng() and 0 is returned; the parent gets the child's pid.
 */
static pid_t xfork(void)
{
    pid_t pid = fork();

    if (pid == 0) {
        init_rng();
        return 0;
    }
    if (pid < 0)
        die_perror("fork");

    return pid;
}
0547 
/*
 * Interpret a waitpid() status for the child labelled what.
 *
 * Returns 0 for a clean exit. For a non-zero exit it prints the status
 * and returns it. It exits via xerror() if the child was killed or
 * stopped by a signal, and returns 111 for any other status.
 */
static int rcheck(int wstatus, const char *what)
{
    if (WIFEXITED(wstatus)) {
        int code = WEXITSTATUS(wstatus);

        if (code != 0)
            fprintf(stderr, "%s exited, status=%d\n", what, code);
        return code;
    }

    if (WIFSIGNALED(wstatus))
        xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
    else if (WIFSTOPPED(wstatus))
        xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

    return 111;
}
0563 
/*
 * Test driver.
 *
 * Forks a server and a client process, which sequence the test over a
 * unix datagram socketpair. The client is started only after the server
 * reports "conn" (listening socket ready). Both children are reaped.
 *
 * Returns the server's non-zero status if there is one, otherwise the
 * client's.
 */
int main(int argc, char *argv[])
{
    int e1, e2, wstatus;
    pid_t s, c, ret;
    int unixfds[2];
    char ready[4];
    ssize_t n;

    parse_opts(argc, argv);

    if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
        die_perror("socketpair");

    s = xfork();
    if (s == 0)
        return server(unixfds[1]);

    close(unixfds[1]);

    /* wait until server bound a socket */
    n = read(unixfds[0], ready, sizeof(ready));
    assert(n == (ssize_t)sizeof(ready));
    assert(memcmp(ready, "conn", sizeof(ready)) == 0);

    c = xfork();
    if (c == 0)
        return client(unixfds[0]);

    close(unixfds[0]);

    ret = waitpid(s, &wstatus, 0);
    if (ret == -1)
        die_perror("waitpid");
    e1 = rcheck(wstatus, "server");
    ret = waitpid(c, &wstatus, 0);
    if (ret == -1)
        die_perror("waitpid");
    e2 = rcheck(wstatus, "client");

    return e1 ? e1 : e2;
}