Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 
0003 #define _GNU_SOURCE
0004 
0005 #include <arpa/inet.h>
0006 #include <errno.h>
0007 #include <error.h>
0008 #include <fcntl.h>
0009 #include <limits.h>
0010 #include <linux/filter.h>
0011 #include <linux/bpf.h>
0012 #include <linux/if_packet.h>
0013 #include <linux/if_vlan.h>
0014 #include <linux/virtio_net.h>
0015 #include <net/if.h>
0016 #include <net/ethernet.h>
0017 #include <netinet/ip.h>
0018 #include <netinet/udp.h>
0019 #include <poll.h>
0020 #include <sched.h>
0021 #include <stdbool.h>
0022 #include <stdint.h>
0023 #include <stdio.h>
0024 #include <stdlib.h>
0025 #include <string.h>
0026 #include <sys/mman.h>
0027 #include <sys/socket.h>
0028 #include <sys/stat.h>
0029 #include <sys/types.h>
0030 #include <unistd.h>
0031 
0032 #include "psock_lib.h"
0033 
0034 static bool cfg_use_bind;
0035 static bool cfg_use_csum_off;
0036 static bool cfg_use_csum_off_bad;
0037 static bool cfg_use_dgram;
0038 static bool cfg_use_gso;
0039 static bool cfg_use_qdisc_bypass;
0040 static bool cfg_use_vlan;
0041 static bool cfg_use_vnet;
0042 
0043 static char *cfg_ifname = "lo";
0044 static int  cfg_mtu = 1500;
0045 static int  cfg_payload_len = DATA_LEN;
0046 static int  cfg_truncate_len = INT_MAX;
0047 static uint16_t cfg_port = 8000;
0048 
0049 /* test sending up to max mtu + 1 */
0050 #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
0051 
0052 static char tbuf[TEST_SZ], rbuf[TEST_SZ];
0053 
/* Sum @num_u16 consecutive 16-bit words starting at @start.
 * The result is left unfolded; carries are folded by the caller.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	unsigned long acc = 0;
	int left = num_u16;

	while (left-- > 0)
		acc += *start++;

	return acc;
}
0064 
/* Internet checksum helper.
 *
 * Adds the @num_u16 16-bit words at @start to the running @sum, folds
 * the carries back into the low 16 bits, and returns the ones'
 * complement of the folded value.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long total = sum + add_csum_hword(start, num_u16);

	while (total > 0xffff)
		total = (total >> 16) + (total & 0xffff);

	return ~total;
}
0075 
0076 static int build_vnet_header(void *header)
0077 {
0078     struct virtio_net_hdr *vh = header;
0079 
0080     vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
0081 
0082     if (cfg_use_csum_off) {
0083         vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
0084         vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
0085         vh->csum_offset = __builtin_offsetof(struct udphdr, check);
0086 
0087         /* position check field exactly one byte beyond end of packet */
0088         if (cfg_use_csum_off_bad)
0089             vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
0090                       vh->csum_offset - 1;
0091     }
0092 
0093     if (cfg_use_gso) {
0094         vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
0095         vh->gso_size = cfg_mtu - sizeof(struct iphdr);
0096     }
0097 
0098     return sizeof(*vh);
0099 }
0100 
0101 static int build_eth_header(void *header)
0102 {
0103     struct ethhdr *eth = header;
0104 
0105     if (cfg_use_vlan) {
0106         uint16_t *tag = header + ETH_HLEN;
0107 
0108         eth->h_proto = htons(ETH_P_8021Q);
0109         tag[1] = htons(ETH_P_IP);
0110         return ETH_HLEN + 4;
0111     }
0112 
0113     eth->h_proto = htons(ETH_P_IP);
0114     return ETH_HLEN;
0115 }
0116 
/* Write a 20-byte IPv4 header (no options) at @header.
 *
 * The header describes a UDP datagram carrying @payload_len bytes of
 * payload, from 172.17.0.2 to 172.17.0.1, and includes the header
 * checksum. Fields not set here (tos, frag_off) keep whatever the
 * buffer already holds.
 *
 * Returns the header length in bytes.
 */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	iph->ihl = 5;
	iph->version = 4;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	/* unsigned: 172 << 24 does not fit in a signed int (UB) */
	iph->saddr = htonl((172U << 24) | (17U << 16) | 2U);
	iph->daddr = htonl((172U << 24) | (17U << 16) | 1U);

	/* The checksum covers the header with the check field zeroed. Clear
	 * it explicitly, so a stale value from a reused buffer cannot leak
	 * into the sum.
	 */
	iph->check = 0;
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
0133 
0134 static int build_udp_header(void *header, int payload_len)
0135 {
0136     const int alen = sizeof(uint32_t);
0137     struct udphdr *udph = header;
0138     int len = sizeof(*udph) + payload_len;
0139 
0140     udph->source = htons(9);
0141     udph->dest = htons(cfg_port);
0142     udph->len = htons(len);
0143 
0144     if (cfg_use_csum_off)
0145         udph->check = build_ip_csum(header - (2 * alen), alen,
0146                         htons(IPPROTO_UDP) + udph->len);
0147     else
0148         udph->check = 0;
0149 
0150     return sizeof(*udph);
0151 }
0152 
0153 static int build_packet(int payload_len)
0154 {
0155     int off = 0;
0156 
0157     off += build_vnet_header(tbuf);
0158     off += build_eth_header(tbuf + off);
0159     off += build_ipv4_header(tbuf + off, payload_len);
0160     off += build_udp_header(tbuf + off, payload_len);
0161 
0162     if (off + payload_len > sizeof(tbuf))
0163         error(1, 0, "payload length exceeds max");
0164 
0165     memset(tbuf + off, DATA_CHAR, payload_len);
0166 
0167     return off + payload_len;
0168 }
0169 
0170 static void do_bind(int fd)
0171 {
0172     struct sockaddr_ll laddr = {0};
0173 
0174     laddr.sll_family = AF_PACKET;
0175     laddr.sll_protocol = htons(ETH_P_IP);
0176     laddr.sll_ifindex = if_nametoindex(cfg_ifname);
0177     if (!laddr.sll_ifindex)
0178         error(1, errno, "if_nametoindex");
0179 
0180     if (bind(fd, (void *)&laddr, sizeof(laddr)))
0181         error(1, errno, "bind");
0182 }
0183 
0184 static void do_send(int fd, char *buf, int len)
0185 {
0186     int ret;
0187 
0188     if (!cfg_use_vnet) {
0189         buf += sizeof(struct virtio_net_hdr);
0190         len -= sizeof(struct virtio_net_hdr);
0191     }
0192     if (cfg_use_dgram) {
0193         buf += ETH_HLEN;
0194         len -= ETH_HLEN;
0195     }
0196 
0197     if (cfg_use_bind) {
0198         ret = write(fd, buf, len);
0199     } else {
0200         struct sockaddr_ll laddr = {0};
0201 
0202         laddr.sll_protocol = htons(ETH_P_IP);
0203         laddr.sll_ifindex = if_nametoindex(cfg_ifname);
0204         if (!laddr.sll_ifindex)
0205             error(1, errno, "if_nametoindex");
0206 
0207         ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
0208     }
0209 
0210     if (ret == -1)
0211         error(1, errno, "write");
0212     if (ret != len)
0213         error(1, 0, "write: %u %u", ret, len);
0214 
0215     fprintf(stderr, "tx: %u\n", ret);
0216 }
0217 
0218 static int do_tx(void)
0219 {
0220     const int one = 1;
0221     int fd, len;
0222 
0223     fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
0224     if (fd == -1)
0225         error(1, errno, "socket t");
0226 
0227     if (cfg_use_bind)
0228         do_bind(fd);
0229 
0230     if (cfg_use_qdisc_bypass &&
0231         setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
0232         error(1, errno, "setsockopt qdisc bypass");
0233 
0234     if (cfg_use_vnet &&
0235         setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
0236         error(1, errno, "setsockopt vnet");
0237 
0238     len = build_packet(cfg_payload_len);
0239 
0240     if (cfg_truncate_len < len)
0241         len = cfg_truncate_len;
0242 
0243     do_send(fd, tbuf, len);
0244 
0245     if (close(fd))
0246         error(1, errno, "close t");
0247 
0248     return len;
0249 }
0250 
0251 static int setup_rx(void)
0252 {
0253     struct timeval tv = { .tv_usec = 100 * 1000 };
0254     struct sockaddr_in raddr = {0};
0255     int fd;
0256 
0257     fd = socket(PF_INET, SOCK_DGRAM, 0);
0258     if (fd == -1)
0259         error(1, errno, "socket r");
0260 
0261     if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
0262         error(1, errno, "setsockopt rcv timeout");
0263 
0264     raddr.sin_family = AF_INET;
0265     raddr.sin_port = htons(cfg_port);
0266     raddr.sin_addr.s_addr = htonl(INADDR_ANY);
0267 
0268     if (bind(fd, (void *)&raddr, sizeof(raddr)))
0269         error(1, errno, "bind r");
0270 
0271     return fd;
0272 }
0273 
0274 static void do_rx(int fd, int expected_len, char *expected)
0275 {
0276     int ret;
0277 
0278     ret = recv(fd, rbuf, sizeof(rbuf), 0);
0279     if (ret == -1)
0280         error(1, errno, "recv");
0281     if (ret != expected_len)
0282         error(1, 0, "recv: %u != %u", ret, expected_len);
0283 
0284     if (memcmp(rbuf, expected, ret))
0285         error(1, 0, "recv: data mismatch");
0286 
0287     fprintf(stderr, "rx: %u\n", ret);
0288 }
0289 
/* Create the sniffer that observes the transmitted frame.
 *
 * Opens a raw packet socket with a 100 ms receive timeout, attaches the
 * filter from pair_udp_setfilter(), and binds the socket to cfg_ifname.
 * Returns the socket; exits via error() on failure.
 */
static int setup_sniffer(void)
{
	struct timeval timeout = { .tv_sec = 0, .tv_usec = 100 * 1000 };
	int sock = socket(PF_PACKET, SOCK_RAW, 0);

	if (sock == -1)
		error(1, errno, "socket p");

	if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO,
		       &timeout, sizeof(timeout)))
		error(1, errno, "setsockopt rcv timeout");

	pair_udp_setfilter(sock);
	do_bind(sock);

	return sock;
}
0307 
0308 static void parse_opts(int argc, char **argv)
0309 {
0310     int c;
0311 
0312     while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
0313         switch (c) {
0314         case 'b':
0315             cfg_use_bind = true;
0316             break;
0317         case 'c':
0318             cfg_use_csum_off = true;
0319             break;
0320         case 'C':
0321             cfg_use_csum_off_bad = true;
0322             break;
0323         case 'd':
0324             cfg_use_dgram = true;
0325             break;
0326         case 'g':
0327             cfg_use_gso = true;
0328             break;
0329         case 'l':
0330             cfg_payload_len = strtoul(optarg, NULL, 0);
0331             break;
0332         case 'q':
0333             cfg_use_qdisc_bypass = true;
0334             break;
0335         case 't':
0336             cfg_truncate_len = strtoul(optarg, NULL, 0);
0337             break;
0338         case 'v':
0339             cfg_use_vnet = true;
0340             break;
0341         case 'V':
0342             cfg_use_vlan = true;
0343             break;
0344         default:
0345             error(1, 0, "%s: parse error", argv[0]);
0346         }
0347     }
0348 
0349     if (cfg_use_vlan && cfg_use_dgram)
0350         error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
0351 
0352     if (cfg_use_csum_off && !cfg_use_vnet)
0353         error(1, 0, "option csum offload (-c) requires vnet (-v)");
0354 
0355     if (cfg_use_csum_off_bad && !cfg_use_csum_off)
0356         error(1, 0, "option csum bad (-C) requires csum offload (-c)");
0357 
0358     if (cfg_use_gso && !cfg_use_csum_off)
0359         error(1, 0, "option gso (-g) requires csum offload (-c)");
0360 }
0361 
0362 static void run_test(void)
0363 {
0364     int fdr, fds, total_len;
0365 
0366     fdr = setup_rx();
0367     fds = setup_sniffer();
0368 
0369     total_len = do_tx();
0370 
0371     /* BPF filter accepts only this length, vlan changes MAC */
0372     if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
0373         do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
0374               tbuf + sizeof(struct virtio_net_hdr));
0375 
0376     do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
0377 
0378     if (close(fds))
0379         error(1, errno, "close s");
0380     if (close(fdr))
0381         error(1, errno, "close r");
0382 }
0383 
/* Run shell command @cmd via system(); on failure, exit with @what as the
 * message.
 *
 * errno is meaningful only when system() itself fails (returns -1). It is
 * not reported when the command merely exits non-zero, since it would be
 * stale in that case.
 */
static void run_cmd(const char *cmd, const char *what)
{
	int status = system(cmd);

	if (status == -1)
		error(1, errno, "%s", what);
	if (status != 0)
		error(1, 0, "%s: status %d", what, status);
}

int main(int argc, char **argv)
{
	parse_opts(argc, argv);

	/* configure lo: MTU equal to the cfg_mtu default (1500) and the
	 * 172.17.0.1/24 address used as the frame's destination
	 */
	run_cmd("ip link set dev lo mtu 1500", "ip link set mtu");
	run_cmd("ip addr add dev lo 172.17.0.1/24", "ip addr add");
	run_cmd("sysctl -w net.ipv4.conf.lo.accept_local=1",
		"sysctl lo.accept_local");

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}