0001
0002
0003 #define _GNU_SOURCE
0004
0005 #include <assert.h>
0006 #include <errno.h>
0007 #include <fcntl.h>
0008 #include <limits.h>
0009 #include <string.h>
0010 #include <stdarg.h>
0011 #include <stdbool.h>
0012 #include <stdint.h>
0013 #include <inttypes.h>
0014 #include <stdio.h>
0015 #include <stdlib.h>
0016 #include <strings.h>
0017 #include <unistd.h>
0018 #include <time.h>
0019
0020 #include <sys/ioctl.h>
0021 #include <sys/socket.h>
0022 #include <sys/types.h>
0023 #include <sys/wait.h>
0024
0025 #include <netdb.h>
0026 #include <netinet/in.h>
0027
0028 #include <linux/tcp.h>
0029 #include <linux/sockios.h>
0030
/* Fallback values for C library headers that predate MPTCP support. */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/* Address family shared by both peers; the -6 option selects AF_INET6. */
static int pf = AF_INET;
/* Socket protocol of the connecting (client) side, set by -t. */
static int proto_tx = IPPROTO_MPTCP;
/* Socket protocol of the listening (server) side, set by -r. */
static int proto_rx = IPPROTO_MPTCP;
0041
/* Print @what followed by the current errno description, then exit(1). */
static void die_perror(const char *what)
{
	perror(what);
	exit(1);
}
0047
/*
 * Write the command-line synopsis to stderr and terminate the process
 * with exit status @r.
 */
static void die_usage(int r)
{
	const char *synopsis =
		"Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

	fputs(synopsis, stderr);
	exit(r);
}
0053
/*
 * printf-style fatal error: format @fmt with the variadic arguments onto
 * stderr, append a newline and exit with status 1.
 *
 * format(printf, 1, 2) lets -Wformat type-check every caller's arguments
 * against @fmt, which a plain variadic wrapper would otherwise hide.
 * noreturn records that control never comes back to the caller.
 */
static void __attribute__((noreturn, format(printf, 1, 2)))
xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
0064
/*
 * Turn a getaddrinfo()-family error code into a message.  For EAI_SYSTEM
 * the underlying cause lives in errno, so strerror(errno) is used instead
 * of gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
0072
/*
 * getaddrinfo() that only returns on success.  On failure it prints
 * "Fatal: getaddrinfo(node:service): reason" to stderr (NULL @node or
 * @service shown as an empty string) and exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *host = node ? node : "";
	const char *serv = service ? service : "";
	int rc;

	rc = getaddrinfo(node, service, hints, res);
	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		host, serv, getxinfo_strerr(rc));
	exit(1);
}
0087
/*
 * Create a socket of protocol proto_rx, bind it to the numeric address
 * @listenaddr:@port (address family taken from the global pf) and put it
 * into the listening state with a backlog of 20.
 *
 * Candidate addresses whose socket() or bind() fails are skipped.  The
 * process exits when none can be bound or when listen() fails.
 * Returns the listening socket descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		/* hints only pick the addresses; the socket itself uses proto_rx. */
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* Best effort: a SO_REUSEADDR failure is reported, not fatal. */
		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
			       sizeof(one)) == -1)
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break;

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
0133
/*
 * Resolve @remoteaddr:@port (address family taken from the global pf) and
 * connect a SOCK_STREAM socket of protocol @proto to the first candidate
 * address that accepts.
 *
 * Candidates whose socket() or connect() fails are skipped, with connect
 * errors reported via perror(), so the lookup result is fully tried before
 * giving up.  Exits when no candidate could be connected.
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break;

		/* Report, release this socket and try the next candidate. */
		perror("connect");
		close(sock);
		sock = -1;
	}

	/* Release the address list before the (possibly fatal) check. */
	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}
0164
/*
 * Map a protocol name from the command line ("tcp" or "mptcp", compared
 * case-insensitively) to its IPPROTO_* number.  Any other string prints
 * the usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp",   IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t i;

	for (i = 0; i < sizeof(known) / sizeof(known[0]); i++) {
		if (strcasecmp(s, known[i].name) == 0)
			return known[i].proto;
	}

	die_usage(1);
	return 0;
}
0175
0176 static void parse_opts(int argc, char **argv)
0177 {
0178 int c;
0179
0180 while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
0181 switch (c) {
0182 case 'h':
0183 die_usage(0);
0184 break;
0185 case '6':
0186 pf = AF_INET6;
0187 break;
0188 case 't':
0189 proto_tx = protostr_to_num(optarg);
0190 break;
0191 case 'r':
0192 proto_rx = protostr_to_num(optarg);
0193 break;
0194 default:
0195 die_usage(1);
0196 break;
0197 }
0198 }
0199 }
0200
0201
/*
 * Poll @fd until TIOCOUTQ reports an empty send queue.  The poll runs
 * about once per millisecond, at most @timeout times.  The process exits
 * if data is still queued after that.
 *
 * Each poll also sanity-checks the ioctls:
 *  - TIOCOUTQ must never exceed @total, the number of bytes the caller
 *    has written;
 *  - SIOCOUTQNSD must not exceed TIOCOUTQ.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued and timeout are ints: print them with %d, not %u. */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
0233
/*
 * Client side of the test.  The server drives it through 4-byte commands
 * on @unixfd:
 *  1. "xmit": reply with the length of a random payload of 128..4094
 *     upper-case letters, then send that payload on @fd;
 *  2. "huge": reply with a random total of at least 1 MiB and below
 *     17 MiB, wait until the first payload has left the send queue, then
 *     stream that many bytes on @fd;
 *  3. "shut": wait until the bulk data has left the send queue, send one
 *     more byte, close @fd and reply "closed".
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* Phase 1: "xmit" -> announce the length, then send the payload. */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* Phase 2: "huge" -> announce the bulk size, then stream it. */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);	/* previously unchecked before strncmp() */
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	/* The first payload must drain from the send queue before bulk data. */
	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		len = total > sizeof(buf) ? sizeof(buf) : total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;
	}

	/* Phase 3: "shut" -> drain, send a final byte, close, report back. */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
0310
/*
 * Search the ancillary data of a message filled in by recvmsg() for the
 * IPPROTO_TCP / TCP_CM_INQ control message and copy its value into *inqv.
 * Exits the process when no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *c = CMSG_FIRSTHDR(msgh);

	while (c != NULL) {
		if (c->cmsg_level == IPPROTO_TCP && c->cmsg_type == TCP_CM_INQ) {
			/* Copy out instead of dereferencing a cast CMSG_DATA pointer. */
			memcpy(inqv, CMSG_DATA(c), sizeof(*inqv));
			return;
		}
		c = CMSG_NXTHDR(msgh, c);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
0324
/*
 * recvmsg() at most @len bytes from @fd into the single iovec of @msg and
 * store the TCP_CM_INQ value from the ancillary data in *inq.
 *
 * iov_len and msg_controllen are set on every call because recvmsg()
 * overwrites msg_controllen with the length of the control data it
 * actually returned.  Exits on recvmsg() failure or when no control data
 * came back.  Returns the number of payload bytes received (0 at EOF).
 */
static ssize_t recv_with_inq(int fd, struct msghdr *msg, size_t ctl_len,
			     size_t len, unsigned int *inq)
{
	ssize_t ret;

	msg->msg_iov->iov_len = len;
	msg->msg_controllen = ctl_len;

	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg->msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(msg, inq);
	return ret;
}

/*
 * Server side of the test.  @fd is the accepted connection, which the
 * caller is expected to have configured with TCP_INQ.  The client is
 * driven through 4-byte commands on @unixfd:
 *  1. "xmit": wait until FIONREAD shows the whole first payload.  Reading
 *     one byte must report inq == len - 1; reading the rest must return
 *     exactly that many bytes with inq == 0.
 *  2. "huge": read the bulk transfer.  inq may never exceed the number of
 *     bytes still outstanding.
 *  3. "shut": once the client answers "closed", reading its final byte
 *     must report inq == 1, and the following EOF read must still report
 *     inq == 1.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq, remaining;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* Wait until the complete first payload is in the receive queue. */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* Consume one byte: inq must report everything else still queued. */
	recv_with_inq(fd, &msg, sizeof(msg_buf), 1, &tcp_inq);
	assert((size_t)tcp_inq == (expect_len - 1));
	remaining = tcp_inq;

	/* Drain the rest: exactly 'remaining' bytes, leaving nothing queued. */
	ret = recv_with_inq(fd, &msg, sizeof(msg_buf), sizeof(buf), &tcp_inq);
	assert(ret == (ssize_t)remaining);
	assert(tcp_inq == 0);

	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t left;

		ret = recv_with_inq(fd, &msg, sizeof(msg_buf), sizeof(buf),
				    &tcp_inq);
		/* Without this a premature EOF would spin until SIGALRM. */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);
		tot += ret;

		/* inq may never claim more data than the client still has to send. */
		left = expect_len - (size_t)tot;
		if (tcp_inq > left)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, left, expect_len);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	/*
	 * The client sent one last byte and closed.  inq is expected to read 1
	 * both now and at EOF (presumably accounting for the received FIN).
	 */
	ret = recv_with_inq(fd, &msg, sizeof(msg_buf), 1, &tcp_inq);
	assert(ret == 1);
	assert(tcp_inq == 1);

	ret = recv_with_inq(fd, &msg, sizeof(msg_buf), 1, &tcp_inq);
	assert(ret == 0);
	assert(tcp_inq == 1);

	close(fd);
}
0457
/*
 * accept() one connection on the listening socket @s without retrieving
 * the peer address.  Exits the process on failure; returns the new fd.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn < 0)
		die_perror("accept");
	return conn;
}
0467
0468 static int server(int unixfd)
0469 {
0470 int fd = -1, r, on = 1;
0471
0472 switch (pf) {
0473 case AF_INET:
0474 fd = sock_listen_mptcp("127.0.0.1", "15432");
0475 break;
0476 case AF_INET6:
0477 fd = sock_listen_mptcp("::1", "15432");
0478 break;
0479 default:
0480 xerror("Unknown pf %d\n", pf);
0481 break;
0482 }
0483
0484 r = write(unixfd, "conn", 4);
0485 assert(r == 4);
0486
0487 alarm(15);
0488 r = xaccept(fd);
0489
0490 if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
0491 die_perror("setsockopt");
0492
0493 process_one_client(r, unixfd);
0494
0495 return 0;
0496 }
0497
0498 static int client(int unixfd)
0499 {
0500 int fd = -1;
0501
0502 alarm(15);
0503
0504 switch (pf) {
0505 case AF_INET:
0506 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
0507 break;
0508 case AF_INET6:
0509 fd = sock_connect_mptcp("::1", "15432", proto_tx);
0510 break;
0511 default:
0512 xerror("Unknown pf %d\n", pf);
0513 }
0514
0515 connect_one_server(fd, unixfd);
0516
0517 return 0;
0518 }
0519
/*
 * Seed rand() from /dev/urandom.  If the device cannot be opened or a full
 * seed cannot be read, fall back to the current time mixed with the pid,
 * so that processes seeded in the same second still differ.
 *
 * The previous version could pass an uninitialized variable to srand()
 * when open()/read() failed, had an inverted "ret < 0" branch, and treated
 * fd 0 as an open failure.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
0535
/*
 * fork() that exits the process on failure.  The child reseeds rand() via
 * init_rng(), presumably so each child gets its own random sequence.
 * Returns the child's pid in the parent and 0 in the child.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid < 0)
		die_perror("fork");
	if (pid == 0)
		init_rng();

	return pid;
}
0547
/*
 * Interpret the waitpid() status @wstatus of the child named @what.
 * Returns 0 for a clean exit.  A non-zero exit code is reported on stderr
 * and returned.  Termination or stop by a signal is fatal (xerror).
 * Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
0563
/*
 * Test driver:
 *  - parse the options and create an AF_UNIX datagram socketpair used to
 *    coordinate the two peers;
 *  - fork the server and wait for its "conn" message (sent once it is
 *    listening) before forking the client;
 *  - reap both children.
 * Returns the server's failure status if any, else the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char sync_msg[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* Start the client only after the server reports it is listening. */
	n = read(unixfds[0], sync_msg, sizeof(sync_msg));
	assert(n == 4);
	assert(strncmp(sync_msg, "conn", 4) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}