Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 
0003 /*
0004  * Test key rotation for TFO.
0005  * New keys are 'rotated' in two steps:
0006  * 1) Add new key as the 'backup' key 'behind' the primary key
0007  * 2) Make new key the primary by swapping the backup and primary keys
0008  *
 * The rotation is done in stages using multiple sockets bound
 * to the same port via SO_REUSEPORT. This simulates key rotation
 * behind, say, a load balancer. Success means that no cookie is
 * rejected at any point during the rotation, i.e. that
 * TcpExtTCPFastOpenPassiveFail remains 0; this program does not
 * read that counter itself.
0014  */
0015 #define _GNU_SOURCE
0016 #include <arpa/inet.h>
0017 #include <errno.h>
0018 #include <error.h>
0019 #include <stdbool.h>
0020 #include <stdio.h>
0021 #include <stdlib.h>
0022 #include <string.h>
0023 #include <sys/epoll.h>
0024 #include <unistd.h>
0025 #include <netinet/tcp.h>
0026 #include <fcntl.h>
0027 #include <time.h>
0028 
0029 #include "../kselftest.h"
0030 
/* Fallback for libc headers that predate the TCP_FASTOPEN_KEY option. */
#ifndef TCP_FASTOPEN_KEY
#define TCP_FASTOPEN_KEY 33
#endif

/* Number of SO_REUSEPORT listeners sharing PORT. */
#define N_LISTEN 10
/* sysctl file used for keys when the sockopt mode (-s) is not selected. */
#define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"
/* Size in bytes of one TFO key (four 32-bit words). */
#define KEY_LENGTH 16

/* Modes selected on the command line by parse_opts(). */
static bool do_ipv6, do_sockopt, do_rotate;
/* Bytes handed to setsockopt(TCP_FASTOPEN_KEY): one key, or two with -r. */
static int key_len = KEY_LENGTH;
/* The listening sockets, all bound to PORT. */
static int rcv_fds[N_LISTEN];
/* Open descriptor for PROC_FASTOPEN_KEY. */
static int proc_fd;

static const char *IP4_ADDR = "127.0.0.1";
static const char *IP6_ADDR = "::1";
static const int PORT = 8891;
0048 
0049 static void get_keys(int fd, uint32_t *keys)
0050 {
0051     char buf[128];
0052     socklen_t len = KEY_LENGTH * 2;
0053 
0054     if (do_sockopt) {
0055         if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len))
0056             error(1, errno, "Unable to get key");
0057         return;
0058     }
0059     lseek(proc_fd, 0, SEEK_SET);
0060     if (read(proc_fd, buf, sizeof(buf)) <= 0)
0061         error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY);
0062     if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2,
0063         keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8)
0064         error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY);
0065 }
0066 
0067 static void set_keys(int fd, uint32_t *keys)
0068 {
0069     char buf[128];
0070 
0071     if (do_sockopt) {
0072         if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys,
0073             key_len))
0074             error(1, errno, "Unable to set key");
0075         return;
0076     }
0077     if (do_rotate)
0078         snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x",
0079              keys[0], keys[1], keys[2], keys[3], keys[4], keys[5],
0080              keys[6], keys[7]);
0081     else
0082         snprintf(buf, 128, "%08x-%08x-%08x-%08x",
0083              keys[0], keys[1], keys[2], keys[3]);
0084     lseek(proc_fd, 0, SEEK_SET);
0085     if (write(proc_fd, buf, sizeof(buf)) <= 0)
0086         error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY);
0087 }
0088 
0089 static void build_rcv_fd(int family, int proto, int *rcv_fds)
0090 {
0091     struct sockaddr_in  addr4 = {0};
0092     struct sockaddr_in6 addr6 = {0};
0093     struct sockaddr *addr;
0094     int opt = 1, i, sz;
0095     int qlen = 100;
0096     uint32_t keys[8];
0097 
0098     switch (family) {
0099     case AF_INET:
0100         addr4.sin_family = family;
0101         addr4.sin_addr.s_addr = htonl(INADDR_ANY);
0102         addr4.sin_port = htons(PORT);
0103         sz = sizeof(addr4);
0104         addr = (struct sockaddr *)&addr4;
0105         break;
0106     case AF_INET6:
0107         addr6.sin6_family = AF_INET6;
0108         addr6.sin6_addr = in6addr_any;
0109         addr6.sin6_port = htons(PORT);
0110         sz = sizeof(addr6);
0111         addr = (struct sockaddr *)&addr6;
0112         break;
0113     default:
0114         error(1, 0, "Unsupported family %d", family);
0115         /* clang does not recognize error() above as terminating
0116          * the program, so it complains that saddr, sz are
0117          * not initialized when this code path is taken. Silence it.
0118          */
0119         return;
0120     }
0121     for (i = 0; i < ARRAY_SIZE(keys); i++)
0122         keys[i] = rand();
0123     for (i = 0; i < N_LISTEN; i++) {
0124         rcv_fds[i] = socket(family, proto, 0);
0125         if (rcv_fds[i] < 0)
0126             error(1, errno, "failed to create receive socket");
0127         if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt,
0128                    sizeof(opt)))
0129             error(1, errno, "failed to set SO_REUSEPORT");
0130         if (bind(rcv_fds[i], addr, sz))
0131             error(1, errno, "failed to bind receive socket");
0132         if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen,
0133                    sizeof(qlen)))
0134             error(1, errno, "failed to set TCP_FASTOPEN");
0135         set_keys(rcv_fds[i], keys);
0136         if (proto == SOCK_STREAM && listen(rcv_fds[i], 10))
0137             error(1, errno, "failed to listen on receive port");
0138     }
0139 }
0140 
0141 static int connect_and_send(int family, int proto)
0142 {
0143     struct sockaddr_in  saddr4 = {0};
0144     struct sockaddr_in  daddr4 = {0};
0145     struct sockaddr_in6 saddr6 = {0};
0146     struct sockaddr_in6 daddr6 = {0};
0147     struct sockaddr *saddr, *daddr;
0148     int fd, sz, ret;
0149     char data[1];
0150 
0151     switch (family) {
0152     case AF_INET:
0153         saddr4.sin_family = AF_INET;
0154         saddr4.sin_addr.s_addr = htonl(INADDR_ANY);
0155         saddr4.sin_port = 0;
0156 
0157         daddr4.sin_family = AF_INET;
0158         if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr))
0159             error(1, errno, "inet_pton failed: %s", IP4_ADDR);
0160         daddr4.sin_port = htons(PORT);
0161 
0162         sz = sizeof(saddr4);
0163         saddr = (struct sockaddr *)&saddr4;
0164         daddr = (struct sockaddr *)&daddr4;
0165         break;
0166     case AF_INET6:
0167         saddr6.sin6_family = AF_INET6;
0168         saddr6.sin6_addr = in6addr_any;
0169 
0170         daddr6.sin6_family = AF_INET6;
0171         if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr))
0172             error(1, errno, "inet_pton failed: %s", IP6_ADDR);
0173         daddr6.sin6_port = htons(PORT);
0174 
0175         sz = sizeof(saddr6);
0176         saddr = (struct sockaddr *)&saddr6;
0177         daddr = (struct sockaddr *)&daddr6;
0178         break;
0179     default:
0180         error(1, 0, "Unsupported family %d", family);
0181         /* clang does not recognize error() above as terminating
0182          * the program, so it complains that saddr, daddr, sz are
0183          * not initialized when this code path is taken. Silence it.
0184          */
0185         return -1;
0186     }
0187     fd = socket(family, proto, 0);
0188     if (fd < 0)
0189         error(1, errno, "failed to create send socket");
0190     if (bind(fd, saddr, sz))
0191         error(1, errno, "failed to bind send socket");
0192     data[0] = 'a';
0193     ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz);
0194     if (ret != 1)
0195         error(1, errno, "failed to sendto");
0196 
0197     return fd;
0198 }
0199 
0200 static bool is_listen_fd(int fd)
0201 {
0202     int i;
0203 
0204     for (i = 0; i < N_LISTEN; i++) {
0205         if (rcv_fds[i] == fd)
0206             return true;
0207     }
0208     return false;
0209 }
0210 
0211 static void rotate_key(int fd)
0212 {
0213     static int iter;
0214     static uint32_t new_key[4];
0215     uint32_t keys[8];
0216     uint32_t tmp_key[4];
0217     int i;
0218 
0219     if (iter < N_LISTEN) {
0220         /* first set new key as backups */
0221         if (iter == 0) {
0222             for (i = 0; i < ARRAY_SIZE(new_key); i++)
0223                 new_key[i] = rand();
0224         }
0225         get_keys(fd, keys);
0226         memcpy(keys + 4, new_key, KEY_LENGTH);
0227         set_keys(fd, keys);
0228     } else {
0229         /* swap the keys */
0230         get_keys(fd, keys);
0231         memcpy(tmp_key, keys + 4, KEY_LENGTH);
0232         memcpy(keys + 4, keys, KEY_LENGTH);
0233         memcpy(keys, tmp_key, KEY_LENGTH);
0234         set_keys(fd, keys);
0235     }
0236     if (++iter >= (N_LISTEN * 2))
0237         iter = 0;
0238 }
0239 
0240 static void run_one_test(int family)
0241 {
0242     struct epoll_event ev;
0243     int i, send_fd;
0244     int n_loops = 10000;
0245     int rotate_key_fd = 0;
0246     int key_rotate_interval = 50;
0247     int fd, epfd;
0248     char buf[1];
0249 
0250     build_rcv_fd(family, SOCK_STREAM, rcv_fds);
0251     epfd = epoll_create(1);
0252     if (epfd < 0)
0253         error(1, errno, "failed to create epoll");
0254     ev.events = EPOLLIN;
0255     for (i = 0; i < N_LISTEN; i++) {
0256         ev.data.fd = rcv_fds[i];
0257         if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev))
0258             error(1, errno, "failed to register sock epoll");
0259     }
0260     while (n_loops--) {
0261         send_fd = connect_and_send(family, SOCK_STREAM);
0262         if (do_rotate && ((n_loops % key_rotate_interval) == 0)) {
0263             rotate_key(rcv_fds[rotate_key_fd]);
0264             if (++rotate_key_fd >= N_LISTEN)
0265                 rotate_key_fd = 0;
0266         }
0267         while (1) {
0268             i = epoll_wait(epfd, &ev, 1, -1);
0269             if (i < 0)
0270                 error(1, errno, "epoll_wait failed");
0271             if (is_listen_fd(ev.data.fd)) {
0272                 fd = accept(ev.data.fd, NULL, NULL);
0273                 if (fd < 0)
0274                     error(1, errno, "failed to accept");
0275                 ev.data.fd = fd;
0276                 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev))
0277                     error(1, errno, "failed epoll add");
0278                 continue;
0279             }
0280             i = recv(ev.data.fd, buf, sizeof(buf), 0);
0281             if (i != 1)
0282                 error(1, errno, "failed recv data");
0283             if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL))
0284                 error(1, errno, "failed epoll del");
0285             close(ev.data.fd);
0286             break;
0287         }
0288         close(send_fd);
0289     }
0290     for (i = 0; i < N_LISTEN; i++)
0291         close(rcv_fds[i]);
0292 }
0293 
0294 static void parse_opts(int argc, char **argv)
0295 {
0296     int c;
0297 
0298     while ((c = getopt(argc, argv, "46sr")) != -1) {
0299         switch (c) {
0300         case '4':
0301             do_ipv6 = false;
0302             break;
0303         case '6':
0304             do_ipv6 = true;
0305             break;
0306         case 's':
0307             do_sockopt = true;
0308             break;
0309         case 'r':
0310             do_rotate = true;
0311             key_len = KEY_LENGTH * 2;
0312             break;
0313         default:
0314             error(1, 0, "%s: parse error", argv[0]);
0315         }
0316     }
0317 }
0318 
0319 int main(int argc, char **argv)
0320 {
0321     parse_opts(argc, argv);
0322     proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR);
0323     if (proc_fd < 0)
0324         error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY);
0325     srand(time(NULL));
0326     if (do_ipv6)
0327         run_one_test(AF_INET6);
0328     else
0329         run_one_test(AF_INET);
0330     close(proc_fd);
0331     fprintf(stderr, "PASS\n");
0332     return 0;
0333 }