0001
0002
0003 #define _GNU_SOURCE
0004
0005 #include <arpa/inet.h>
0006 #include <errno.h>
0007 #include <error.h>
0008 #include <fcntl.h>
0009 #include <limits.h>
0010 #include <linux/filter.h>
0011 #include <linux/bpf.h>
0012 #include <linux/if_packet.h>
0013 #include <linux/if_vlan.h>
0014 #include <linux/virtio_net.h>
0015 #include <net/if.h>
0016 #include <net/ethernet.h>
0017 #include <netinet/ip.h>
0018 #include <netinet/udp.h>
0019 #include <poll.h>
0020 #include <sched.h>
0021 #include <stdbool.h>
0022 #include <stdint.h>
0023 #include <stdio.h>
0024 #include <stdlib.h>
0025 #include <string.h>
0026 #include <sys/mman.h>
0027 #include <sys/socket.h>
0028 #include <sys/stat.h>
0029 #include <sys/types.h>
0030 #include <unistd.h>
0031
0032 #include "psock_lib.h"
0033
0034 static bool cfg_use_bind;
0035 static bool cfg_use_csum_off;
0036 static bool cfg_use_csum_off_bad;
0037 static bool cfg_use_dgram;
0038 static bool cfg_use_gso;
0039 static bool cfg_use_qdisc_bypass;
0040 static bool cfg_use_vlan;
0041 static bool cfg_use_vnet;
0042
0043 static char *cfg_ifname = "lo";
0044 static int cfg_mtu = 1500;
0045 static int cfg_payload_len = DATA_LEN;
0046 static int cfg_truncate_len = INT_MAX;
0047 static uint16_t cfg_port = 8000;
0048
0049
0050 #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
0051
0052 static char tbuf[TEST_SZ], rbuf[TEST_SZ];
0053
/*
 * Add up num_u16 consecutive 16-bit words starting at start.
 *
 * The words are summed as stored (no byte swapping) and carries are kept
 * in the wide return value; folding is left to the caller. A count of
 * zero or less yields 0.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	const uint16_t *word = start;
	unsigned long total = 0;
	int left;

	for (left = num_u16; left > 0; left--)
		total += *word++;

	return total;
}
0064
/*
 * Compute a 16-bit ones'-complement checksum.
 *
 * @start:   first 16-bit word to include
 * @num_u16: number of 16-bit words to include
 * @sum:     initial value added to the word sum
 *
 * Returns the bitwise complement of the folded sum, truncated to 16 bits.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long acc = sum + add_csum_hword(start, num_u16);

	/* Fold everything above bit 15 back into the low 16 bits. */
	while (acc > 0xffff)
		acc = (acc >> 16) + (acc & 0xffff);

	return (uint16_t)~acc;
}
0075
0076 static int build_vnet_header(void *header)
0077 {
0078 struct virtio_net_hdr *vh = header;
0079
0080 vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
0081
0082 if (cfg_use_csum_off) {
0083 vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
0084 vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
0085 vh->csum_offset = __builtin_offsetof(struct udphdr, check);
0086
0087
0088 if (cfg_use_csum_off_bad)
0089 vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
0090 vh->csum_offset - 1;
0091 }
0092
0093 if (cfg_use_gso) {
0094 vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
0095 vh->gso_size = cfg_mtu - sizeof(struct iphdr);
0096 }
0097
0098 return sizeof(*vh);
0099 }
0100
0101 static int build_eth_header(void *header)
0102 {
0103 struct ethhdr *eth = header;
0104
0105 if (cfg_use_vlan) {
0106 uint16_t *tag = header + ETH_HLEN;
0107
0108 eth->h_proto = htons(ETH_P_8021Q);
0109 tag[1] = htons(ETH_P_IP);
0110 return ETH_HLEN + 4;
0111 }
0112
0113 eth->h_proto = htons(ETH_P_IP);
0114 return ETH_HLEN;
0115 }
0116
/*
 * Build a 20-byte IPv4 header (no options) for a UDP datagram that carries
 * payload_len bytes of data, from 172.17.0.2 to 172.17.0.1.
 *
 * @header:      where to write the header
 * @payload_len: number of UDP payload bytes that will follow the UDP header
 *
 * Returns the header length in bytes.
 */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	iph->ihl = 5;
	iph->version = 4;
	/*
	 * Every field takes part in the checksum below, so set the ones this
	 * test does not use explicitly instead of relying on the buffer being
	 * zeroed by the caller.
	 */
	iph->tos = 0;
	iph->frag_off = 0;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
	iph->daddr = htonl((172 << 24) | (17 << 16) | 1);

	/* The checksum field itself must read as zero while it is summed. */
	iph->check = 0;
	/* ihl counts 32-bit words: << 1 gives 16-bit words, << 2 gives bytes. */
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
0133
0134 static int build_udp_header(void *header, int payload_len)
0135 {
0136 const int alen = sizeof(uint32_t);
0137 struct udphdr *udph = header;
0138 int len = sizeof(*udph) + payload_len;
0139
0140 udph->source = htons(9);
0141 udph->dest = htons(cfg_port);
0142 udph->len = htons(len);
0143
0144 if (cfg_use_csum_off)
0145 udph->check = build_ip_csum(header - (2 * alen), alen,
0146 htons(IPPROTO_UDP) + udph->len);
0147 else
0148 udph->check = 0;
0149
0150 return sizeof(*udph);
0151 }
0152
0153 static int build_packet(int payload_len)
0154 {
0155 int off = 0;
0156
0157 off += build_vnet_header(tbuf);
0158 off += build_eth_header(tbuf + off);
0159 off += build_ipv4_header(tbuf + off, payload_len);
0160 off += build_udp_header(tbuf + off, payload_len);
0161
0162 if (off + payload_len > sizeof(tbuf))
0163 error(1, 0, "payload length exceeds max");
0164
0165 memset(tbuf + off, DATA_CHAR, payload_len);
0166
0167 return off + payload_len;
0168 }
0169
0170 static void do_bind(int fd)
0171 {
0172 struct sockaddr_ll laddr = {0};
0173
0174 laddr.sll_family = AF_PACKET;
0175 laddr.sll_protocol = htons(ETH_P_IP);
0176 laddr.sll_ifindex = if_nametoindex(cfg_ifname);
0177 if (!laddr.sll_ifindex)
0178 error(1, errno, "if_nametoindex");
0179
0180 if (bind(fd, (void *)&laddr, sizeof(laddr)))
0181 error(1, errno, "bind");
0182 }
0183
0184 static void do_send(int fd, char *buf, int len)
0185 {
0186 int ret;
0187
0188 if (!cfg_use_vnet) {
0189 buf += sizeof(struct virtio_net_hdr);
0190 len -= sizeof(struct virtio_net_hdr);
0191 }
0192 if (cfg_use_dgram) {
0193 buf += ETH_HLEN;
0194 len -= ETH_HLEN;
0195 }
0196
0197 if (cfg_use_bind) {
0198 ret = write(fd, buf, len);
0199 } else {
0200 struct sockaddr_ll laddr = {0};
0201
0202 laddr.sll_protocol = htons(ETH_P_IP);
0203 laddr.sll_ifindex = if_nametoindex(cfg_ifname);
0204 if (!laddr.sll_ifindex)
0205 error(1, errno, "if_nametoindex");
0206
0207 ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
0208 }
0209
0210 if (ret == -1)
0211 error(1, errno, "write");
0212 if (ret != len)
0213 error(1, 0, "write: %u %u", ret, len);
0214
0215 fprintf(stderr, "tx: %u\n", ret);
0216 }
0217
0218 static int do_tx(void)
0219 {
0220 const int one = 1;
0221 int fd, len;
0222
0223 fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
0224 if (fd == -1)
0225 error(1, errno, "socket t");
0226
0227 if (cfg_use_bind)
0228 do_bind(fd);
0229
0230 if (cfg_use_qdisc_bypass &&
0231 setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
0232 error(1, errno, "setsockopt qdisc bypass");
0233
0234 if (cfg_use_vnet &&
0235 setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
0236 error(1, errno, "setsockopt vnet");
0237
0238 len = build_packet(cfg_payload_len);
0239
0240 if (cfg_truncate_len < len)
0241 len = cfg_truncate_len;
0242
0243 do_send(fd, tbuf, len);
0244
0245 if (close(fd))
0246 error(1, errno, "close t");
0247
0248 return len;
0249 }
0250
0251 static int setup_rx(void)
0252 {
0253 struct timeval tv = { .tv_usec = 100 * 1000 };
0254 struct sockaddr_in raddr = {0};
0255 int fd;
0256
0257 fd = socket(PF_INET, SOCK_DGRAM, 0);
0258 if (fd == -1)
0259 error(1, errno, "socket r");
0260
0261 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
0262 error(1, errno, "setsockopt rcv timeout");
0263
0264 raddr.sin_family = AF_INET;
0265 raddr.sin_port = htons(cfg_port);
0266 raddr.sin_addr.s_addr = htonl(INADDR_ANY);
0267
0268 if (bind(fd, (void *)&raddr, sizeof(raddr)))
0269 error(1, errno, "bind r");
0270
0271 return fd;
0272 }
0273
0274 static void do_rx(int fd, int expected_len, char *expected)
0275 {
0276 int ret;
0277
0278 ret = recv(fd, rbuf, sizeof(rbuf), 0);
0279 if (ret == -1)
0280 error(1, errno, "recv");
0281 if (ret != expected_len)
0282 error(1, 0, "recv: %u != %u", ret, expected_len);
0283
0284 if (memcmp(rbuf, expected, ret))
0285 error(1, 0, "recv: data mismatch");
0286
0287 fprintf(stderr, "rx: %u\n", ret);
0288 }
0289
/*
 * Create the raw packet socket used to observe the transmitted frame:
 * 100 ms receive timeout, the filter installed by pair_udp_setfilter(),
 * then bound via do_bind().
 *
 * Returns the socket descriptor; exits the program on any failure.
 */
static int setup_sniffer(void)
{
	struct timeval timeout = { .tv_sec = 0, .tv_usec = 100 * 1000 };
	int sock = socket(PF_PACKET, SOCK_RAW, 0);

	if (sock == -1)
		error(1, errno, "socket p");

	if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO,
		       &timeout, sizeof(timeout)) != 0)
		error(1, errno, "setsockopt rcv timeout");

	/* Install the filter before binding, as the original ordering does. */
	pair_udp_setfilter(sock);
	do_bind(sock);

	return sock;
}
0307
0308 static void parse_opts(int argc, char **argv)
0309 {
0310 int c;
0311
0312 while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
0313 switch (c) {
0314 case 'b':
0315 cfg_use_bind = true;
0316 break;
0317 case 'c':
0318 cfg_use_csum_off = true;
0319 break;
0320 case 'C':
0321 cfg_use_csum_off_bad = true;
0322 break;
0323 case 'd':
0324 cfg_use_dgram = true;
0325 break;
0326 case 'g':
0327 cfg_use_gso = true;
0328 break;
0329 case 'l':
0330 cfg_payload_len = strtoul(optarg, NULL, 0);
0331 break;
0332 case 'q':
0333 cfg_use_qdisc_bypass = true;
0334 break;
0335 case 't':
0336 cfg_truncate_len = strtoul(optarg, NULL, 0);
0337 break;
0338 case 'v':
0339 cfg_use_vnet = true;
0340 break;
0341 case 'V':
0342 cfg_use_vlan = true;
0343 break;
0344 default:
0345 error(1, 0, "%s: parse error", argv[0]);
0346 }
0347 }
0348
0349 if (cfg_use_vlan && cfg_use_dgram)
0350 error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
0351
0352 if (cfg_use_csum_off && !cfg_use_vnet)
0353 error(1, 0, "option csum offload (-c) requires vnet (-v)");
0354
0355 if (cfg_use_csum_off_bad && !cfg_use_csum_off)
0356 error(1, 0, "option csum bad (-C) requires csum offload (-c)");
0357
0358 if (cfg_use_gso && !cfg_use_csum_off)
0359 error(1, 0, "option gso (-g) requires csum offload (-c)");
0360 }
0361
0362 static void run_test(void)
0363 {
0364 int fdr, fds, total_len;
0365
0366 fdr = setup_rx();
0367 fds = setup_sniffer();
0368
0369 total_len = do_tx();
0370
0371
0372 if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
0373 do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
0374 tbuf + sizeof(struct virtio_net_hdr));
0375
0376 do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
0377
0378 if (close(fds))
0379 error(1, errno, "close s");
0380 if (close(fdr))
0381 error(1, errno, "close r");
0382 }
0383
/*
 * Run a shell command needed to prepare the loopback device and exit the
 * program with a diagnostic if it does not succeed.
 *
 * errno is only reported when system() itself fails (returns -1); a
 * command that merely exits non-zero leaves errno meaningless, so the raw
 * status is printed instead.
 */
static void run_setup_cmd(const char *cmd, const char *what)
{
	int status = system(cmd);

	if (status == -1)
		error(1, errno, "%s", what);
	if (status != 0)
		error(1, 0, "%s: command failed (status %d)", what, status);
}

/*
 * Parse options, configure the loopback device (MTU 1500, address
 * 172.17.0.1/24, accept_local enabled) and run the test once.
 * Prints "OK" and returns 0 when the test passes.
 */
int main(int argc, char **argv)
{
	parse_opts(argc, argv);

	run_setup_cmd("ip link set dev lo mtu 1500", "ip link set mtu");
	run_setup_cmd("ip addr add dev lo 172.17.0.1/24", "ip addr add");
	run_setup_cmd("sysctl -w net.ipv4.conf.lo.accept_local=1",
		      "sysctl lo.accept_local");

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}