Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /*
0003  * ipsec.c - Check xfrm on veth inside a net-ns.
0004  * Copyright (c) 2018 Dmitry Safonov
0005  */
0006 
0007 #define _GNU_SOURCE
0008 
0009 #include <arpa/inet.h>
0010 #include <asm/types.h>
0011 #include <errno.h>
0012 #include <fcntl.h>
0013 #include <limits.h>
0014 #include <linux/limits.h>
0015 #include <linux/netlink.h>
0016 #include <linux/random.h>
0017 #include <linux/rtnetlink.h>
0018 #include <linux/veth.h>
0019 #include <linux/xfrm.h>
0020 #include <netinet/in.h>
0021 #include <net/if.h>
0022 #include <sched.h>
0023 #include <stdbool.h>
0024 #include <stdint.h>
0025 #include <stdio.h>
0026 #include <stdlib.h>
0027 #include <string.h>
0028 #include <sys/mman.h>
0029 #include <sys/socket.h>
0030 #include <sys/stat.h>
0031 #include <sys/syscall.h>
0032 #include <sys/types.h>
0033 #include <sys/wait.h>
0034 #include <time.h>
0035 #include <unistd.h>
0036 
0037 #include "../kselftest.h"
0038 
/* Log a message prefixed with the caller's pid and source line number. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends the description of the current errno (%m). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time assertion: a negative array size breaks the build if true. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host numbers 4*nr+1 and 4*nr+2: the two usable addresses of the nr-th
 * /30 block.  The argument is parenthesized so that an expression such as
 * child_ip(i + 1) expands to 4*(i + 1) + 1 rather than 4*i + 1 + 1.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12
0060 
/* Network namespace fds opened by init_namespaces(); -1 until then. */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
static long page_size;

/*
 * ksft_cnt is static in kselftest, so it isn't shared with children.
 * Children therefore send each test result back to the parent, which does
 * the counting.  results_fd is the pipe carrying that feedback.
 */
static int results_fd[2];

/*
 * Ping tunables.  They are file-local: static, like the rest of the
 * file-scope state above.
 */
static const unsigned int ping_delay_nsec	= 50 * 1000 * 1000; /* pause between pings, ns */
static const unsigned int ping_timeout		= 300;	/* receive timeout, microseconds */
static const unsigned int ping_count		= 100;	/* pings per do_ping() run */
static const unsigned int ping_success		= 80;	/* minimum successful pings */
0077 
/*
 * Fill @buflen bytes at @buf with rand() output, one int-sized chunk per
 * rand() call.  A trailing partial chunk takes the leading bytes of one more
 * rand() value.  Every chunk is copied with memcpy(), so @buf need not be
 * int-aligned.  No void-pointer arithmetic is used, which is a GNU extension.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *p = buf;
	size_t words = buflen / sizeof(int);
	size_t leftover = buflen % sizeof(int);

	if (!buflen)
		return;

	while (words--) {
		int r = rand();

		memcpy(p, &r, sizeof(r));
		p += sizeof(r);
	}

	if (leftover) {
		int r = rand();

		memcpy(p, &r, leftover);
	}
}
0098 
0099 static int unshare_open(void)
0100 {
0101     const char *netns_path = "/proc/self/ns/net";
0102     int fd;
0103 
0104     if (unshare(CLONE_NEWNET) != 0) {
0105         pr_err("unshare()");
0106         return -1;
0107     }
0108 
0109     fd = open(netns_path, O_RDONLY);
0110     if (fd <= 0) {
0111         pr_err("open(%s)", netns_path);
0112         return -1;
0113     }
0114 
0115     return fd;
0116 }
0117 
0118 static int switch_ns(int fd)
0119 {
0120     if (setns(fd, CLONE_NEWNET)) {
0121         pr_err("setns()");
0122         return -1;
0123     }
0124     return 0;
0125 }
0126 
0127 /*
0128  * Running the test inside a new parent net namespace to bother less
0129  * about cleanup on error-path.
0130  */
0131 static int init_namespaces(void)
0132 {
0133     nsfd_parent = unshare_open();
0134     if (nsfd_parent <= 0)
0135         return -1;
0136 
0137     nsfd_childa = unshare_open();
0138     if (nsfd_childa <= 0)
0139         return -1;
0140 
0141     if (switch_ns(nsfd_parent))
0142         return -1;
0143 
0144     nsfd_childb = unshare_open();
0145     if (nsfd_childb <= 0)
0146         return -1;
0147 
0148     if (switch_ns(nsfd_parent))
0149         return -1;
0150     return 0;
0151 }
0152 
/*
 * Ensure *@sock is an open AF_NETLINK socket of protocol @proto and maintain
 * the request sequence number *@seq_nr.  A freshly opened socket gets a
 * random starting sequence number.  An already open socket (*@sock > 0) is
 * reused, and its sequence number is advanced.
 *
 * Returns 0 on success, -1 if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Advance the caller's counter, not the local pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock < 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
0170 
/* Where the next rtattr would start: the RTA-aligned end of message @nh. */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *base = (char *)nh;
	size_t used = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(base + used);
}
0175 
/*
 * Append an rtattr of type @rta_type carrying @size bytes of @payload at the
 * aligned end of message @nh, and grow nh->nlmsg_len to cover it.
 * @payload may be NULL when @size is 0, as rtattr_begin() does when opening
 * a nested attribute.
 *
 * Returns 0 on success, -1 if the request buffer (@req_sz bytes) is too small.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() from a NULL source is undefined even for a zero length. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
0195 
/*
 * Open a nested attribute.  Pack an rtattr header (plus optional @payload)
 * and return a pointer to it, so rtattr_end() can later stretch its length
 * over everything appended after it.  Returns NULL if the buffer overflows.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *start = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : start;
}
0206 
/* Open a nested attribute that carries no payload of its own. */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
0212 
/* Close nested attribute @attr: extend its length to the current end of @nh. */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	const char *msg_tail = (const char *)nh + nh->nlmsg_len;
	const char *attr_head = (const char *)attr;

	attr->rta_len = msg_tail - attr_head;
}
0219 
/*
 * Pack the VETH_INFO_PEER nested attribute describing the peer end of a veth
 * pair.  It holds an ifinfomsg header, the peer's IFLA_IFNAME (@peer) and
 * IFLA_NET_NS_FD (@ns).  Returns 0 on success, -1 if @req_sz is exceeded.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *nest;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
			     &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
0244 
/*
 * Read the kernel's reply to an NLM_F_ACK request on @sock.
 *
 * Returns 0 for a positive ACK, or the nonzero error code carried by an
 * NLMSG_ERROR reply.  Returns -1 if recv() fails, the reply is too short to
 * hold a header and error code, or the reply is not an NLMSG_ERROR.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	} else if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		/* Don't inspect fields that recv() did not fill in. */
		printk("netlink answer is too short: %zd", len);
		return -1;
	} else if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	} else if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
0267 
0268 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
0269         const char *peerb, int ns_b)
0270 {
0271     uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
0272     struct {
0273         struct nlmsghdr     nh;
0274         struct ifinfomsg    info;
0275         char            attrbuf[MAX_PAYLOAD];
0276     } req;
0277     const char veth_type[] = "veth";
0278     struct rtattr *link_info, *info_data;
0279 
0280     memset(&req, 0, sizeof(req));
0281     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.info));
0282     req.nh.nlmsg_type   = RTM_NEWLINK;
0283     req.nh.nlmsg_flags  = flags;
0284     req.nh.nlmsg_seq    = seq;
0285     req.info.ifi_family = AF_UNSPEC;
0286     req.info.ifi_change = 0xFFFFFFFF;
0287 
0288     if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
0289         return -1;
0290 
0291     if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
0292         return -1;
0293 
0294     link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
0295     if (!link_info)
0296         return -1;
0297 
0298     if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
0299         return -1;
0300 
0301     info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
0302     if (!info_data)
0303         return -1;
0304 
0305     if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
0306         return -1;
0307 
0308     rtattr_end(&req.nh, info_data);
0309     rtattr_end(&req.nh, link_info);
0310 
0311     if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
0312         pr_err("send()");
0313         return -1;
0314     }
0315     return netlink_check_answer(sock);
0316 }
0317 
0318 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
0319         struct in_addr addr, uint8_t prefix)
0320 {
0321     uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
0322     struct {
0323         struct nlmsghdr     nh;
0324         struct ifaddrmsg    info;
0325         char            attrbuf[MAX_PAYLOAD];
0326     } req;
0327 
0328     memset(&req, 0, sizeof(req));
0329     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.info));
0330     req.nh.nlmsg_type   = RTM_NEWADDR;
0331     req.nh.nlmsg_flags  = flags;
0332     req.nh.nlmsg_seq    = seq;
0333     req.info.ifa_family = AF_INET;
0334     req.info.ifa_prefixlen  = prefix;
0335     req.info.ifa_index  = if_nametoindex(intf);
0336 
0337 #ifdef DEBUG
0338     {
0339         char addr_str[IPV4_STR_SZ] = {};
0340 
0341         strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
0342 
0343         printk("ip addr set %s", addr_str);
0344     }
0345 #endif
0346 
0347     if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
0348         return -1;
0349 
0350     if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
0351         return -1;
0352 
0353     if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
0354         pr_err("send()");
0355         return -1;
0356     }
0357     return netlink_check_answer(sock);
0358 }
0359 
0360 static int link_set_up(int sock, uint32_t seq, const char *intf)
0361 {
0362     struct {
0363         struct nlmsghdr     nh;
0364         struct ifinfomsg    info;
0365         char            attrbuf[MAX_PAYLOAD];
0366     } req;
0367 
0368     memset(&req, 0, sizeof(req));
0369     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.info));
0370     req.nh.nlmsg_type   = RTM_NEWLINK;
0371     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
0372     req.nh.nlmsg_seq    = seq;
0373     req.info.ifi_family = AF_UNSPEC;
0374     req.info.ifi_change = 0xFFFFFFFF;
0375     req.info.ifi_index  = if_nametoindex(intf);
0376     req.info.ifi_flags  = IFF_UP;
0377     req.info.ifi_change = IFF_UP;
0378 
0379     if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
0380         pr_err("send()");
0381         return -1;
0382     }
0383     return netlink_check_answer(sock);
0384 }
0385 
0386 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
0387         struct in_addr src, struct in_addr dst)
0388 {
0389     struct {
0390         struct nlmsghdr nh;
0391         struct rtmsg    rt;
0392         char        attrbuf[MAX_PAYLOAD];
0393     } req;
0394     unsigned int index = if_nametoindex(intf);
0395 
0396     memset(&req, 0, sizeof(req));
0397     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.rt));
0398     req.nh.nlmsg_type   = RTM_NEWROUTE;
0399     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
0400     req.nh.nlmsg_seq    = seq;
0401     req.rt.rtm_family   = AF_INET;
0402     req.rt.rtm_dst_len  = 32;
0403     req.rt.rtm_table    = RT_TABLE_MAIN;
0404     req.rt.rtm_protocol = RTPROT_BOOT;
0405     req.rt.rtm_scope    = RT_SCOPE_LINK;
0406     req.rt.rtm_type     = RTN_UNICAST;
0407 
0408     if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
0409         return -1;
0410 
0411     if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
0412         return -1;
0413 
0414     if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
0415         return -1;
0416 
0417     if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
0418         pr_err("send()");
0419         return -1;
0420     }
0421 
0422     return netlink_check_answer(sock);
0423 }
0424 
0425 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
0426         struct in_addr tunsrc, struct in_addr tundst)
0427 {
0428     if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
0429             tunsrc, PREFIX_LEN)) {
0430         printk("Failed to set ipv4 addr");
0431         return -1;
0432     }
0433 
0434     if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
0435         printk("Failed to set ipv4 route");
0436         return -1;
0437     }
0438 
0439     return 0;
0440 }
0441 
0442 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
0443 {
0444     struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
0445     struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
0446     struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
0447     int route_sock = -1, ret = -1;
0448     uint32_t route_seq;
0449 
0450     if (switch_ns(nsfd))
0451         return -1;
0452 
0453     if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
0454         printk("Failed to open netlink route socket in child");
0455         return -1;
0456     }
0457 
0458     if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
0459         printk("Failed to set ipv4 addr");
0460         goto err;
0461     }
0462 
0463     if (link_set_up(route_sock, route_seq++, veth)) {
0464         printk("Failed to bring up %s", veth);
0465         goto err;
0466     }
0467 
0468     if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
0469         printk("Failed to add tunnel route on %s", veth);
0470         goto err;
0471     }
0472     ret = 0;
0473 
0474 err:
0475     close(route_sock);
0476     return ret;
0477 }
0478 
#define ALGO_LEN	64

/* Kinds of xfrm test case; desc_name[] gives a label for each value. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};

/*
 * Label for each desc_type, followed by one empty entry.  Designated
 * initializers keep each label tied to its enumerator.
 */
const char *desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
	[SPDINFO_ATTRS]		= "spdinfo attributes",
	[SPDINFO_ATTRS + 1]	= ""
};

/*
 * One xfrm test case.  Which algorithm names must be non-empty depends on
 * @proto; xfrm_state_pack_algo() validates the combination.
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];	/* authentication (AH, ESP) */
	char		e_algo[ALGO_LEN];	/* encryption (ESP) */
	char		c_algo[ALGO_LEN];	/* compression (IPPROTO_COMP) */
	char		ae_algo[ALGO_LEN];	/* AEAD (ESP) */
	unsigned int	icv_len;		/* AEAD ICV length */
	/* unsigned key_len; */
};
0507 
0508 enum msg_type {
0509     MSG_ACK     = 0,
0510     MSG_EXIT,
0511     MSG_PING,
0512     MSG_XFRM_PREPARE,
0513     MSG_XFRM_ADD,
0514     MSG_XFRM_DEL,
0515     MSG_XFRM_CLEANUP,
0516 };
0517 
0518 struct test_desc {
0519     enum msg_type type;
0520     union {
0521         struct {
0522             in_addr_t reply_ip;
0523             unsigned int port;
0524         } ping;
0525         struct xfrm_desc xfrm_desc;
0526     } body;
0527 };
0528 
0529 struct test_result {
0530     struct xfrm_desc desc;
0531     unsigned int res;
0532 };
0533 
0534 static void write_test_result(unsigned int res, struct xfrm_desc *d)
0535 {
0536     struct test_result tr = {};
0537     ssize_t ret;
0538 
0539     tr.desc = *d;
0540     tr.res = res;
0541 
0542     ret = write(results_fd[1], &tr, sizeof(tr));
0543     if (ret != sizeof(tr))
0544         pr_err("Failed to write the result in pipe %zd", ret);
0545 }
0546 
0547 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
0548 {
0549     ssize_t bytes = write(fd, msg, sizeof(*msg));
0550 
0551     /* Make sure that write/read is atomic to a pipe */
0552     BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
0553 
0554     if (bytes < 0) {
0555         pr_err("write()");
0556         if (exit_of_fail)
0557             exit(KSFT_FAIL);
0558     }
0559     if (bytes != sizeof(*msg)) {
0560         pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
0561         if (exit_of_fail)
0562             exit(KSFT_FAIL);
0563     }
0564 }
0565 
0566 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
0567 {
0568     ssize_t bytes = read(fd, msg, sizeof(*msg));
0569 
0570     if (bytes < 0) {
0571         pr_err("read()");
0572         if (exit_of_fail)
0573             exit(KSFT_FAIL);
0574     }
0575     if (bytes != sizeof(*msg)) {
0576         pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
0577         if (exit_of_fail)
0578             exit(KSFT_FAIL);
0579     }
0580 }
0581 
/*
 * Create the UDP socket pair used for pinging:
 *  - sock[0] is bound to @listen_ip on an ephemeral port, which is stored
 *    in *@server_port in host byte order.  Its receive timeout is
 *    @u_timeout microseconds.
 *  - sock[1] is an unbound socket used for sending.
 * Returns 0 on success.  On failure returns -1 and leaves no socket open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	/* Split so that tv_usec stays below 1000000 (setsockopt: EDOM). */
	struct timeval t = {
		.tv_sec		= u_timeout / 1000000,
		.tv_usec	= u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	memset(&server, 0, sizeof(server));	/* clear sin_zero */
	server.sin_family	= AF_INET;
	server.sin_port		= 0;
	server.sin_addr		= listen_ip;

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
0628 
/*
 * Send @buf_len bytes of @buf from sock[1] to @dest_ip:@port, then wait on
 * sock[0] for the echo and check that it matches @buf byte for byte.
 * @dest_ip is in network byte order; @port is in host byte order.
 * Returns 0 on a good echo, otherwise -1.  A receive that fails with EAGAIN
 * (e.g. the socket's receive timeout expired) is not logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* buf_len bytes; this was char *[buf_len], i.e. buf_len pointers. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
0665 
/*
 * Wait on sock[0] for a ping and check that it matches the @buf_len bytes of
 * @buf.  Then echo @buf back from sock[1] to @dest_ip:@port.
 * @dest_ip is in network byte order; @port is in host byte order.
 * Returns 0 on success, otherwise -1.  A receive that fails with EAGAIN
 * (e.g. the socket's receive timeout expired) is not logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* buf_len bytes; this was char *[buf_len], i.e. buf_len pointers. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
0704 
0705 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
0706         char *buf, size_t buf_len);
0707 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
0708         bool init_side, int d_port, in_addr_t to, ping_f func)
0709 {
0710     struct test_desc msg;
0711     unsigned int s_port, i, ping_succeeded = 0;
0712     int ping_sock[2];
0713     char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
0714 
0715     if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
0716         printk("Failed to init ping");
0717         return -1;
0718     }
0719 
0720     memset(&msg, 0, sizeof(msg));
0721     msg.type        = MSG_PING;
0722     msg.body.ping.port  = s_port;
0723     memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
0724 
0725     write_msg(cmd_fd, &msg, 0);
0726     if (init_side) {
0727         /* The other end sends ip to ping */
0728         read_msg(cmd_fd, &msg, 0);
0729         if (msg.type != MSG_PING)
0730             return -1;
0731         to = msg.body.ping.reply_ip;
0732         d_port = msg.body.ping.port;
0733     }
0734 
0735     for (i = 0; i < ping_count ; i++) {
0736         struct timespec sleep_time = {
0737             .tv_sec = 0,
0738             .tv_nsec = ping_delay_nsec,
0739         };
0740 
0741         ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
0742         nanosleep(&sleep_time, 0);
0743     }
0744 
0745     close(ping_sock[0]);
0746     close(ping_sock[1]);
0747 
0748     strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
0749     strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
0750 
0751     if (ping_succeeded < ping_success) {
0752         printk("ping (%s) %s->%s failed %u/%u times",
0753             init_side ? "send" : "reply", from_str, to_str,
0754             ping_count - ping_succeeded, ping_count);
0755         return -1;
0756     }
0757 
0758 #ifdef DEBUG
0759     printk("ping (%s) %s->%s succeeded %u/%u times",
0760         init_side ? "send" : "reply", from_str, to_str,
0761         ping_succeeded, ping_count);
0762 #endif
0763 
0764     return 0;
0765 }
0766 
0767 static int xfrm_fill_key(char *name, char *buf,
0768         size_t buf_len, unsigned int *key_len)
0769 {
0770     /* TODO: use set/map instead */
0771     if (strncmp(name, "digest_null", ALGO_LEN) == 0)
0772         *key_len = 0;
0773     else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
0774         *key_len = 0;
0775     else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
0776         *key_len = 64;
0777     else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
0778         *key_len = 128;
0779     else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
0780         *key_len = 128;
0781     else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
0782         *key_len = 128;
0783     else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
0784         *key_len = 128;
0785     else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
0786         *key_len = 128;
0787     else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
0788         *key_len = 160;
0789     else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
0790         *key_len = 160;
0791     else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
0792         *key_len = 192;
0793     else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
0794         *key_len = 256;
0795     else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
0796         *key_len = 256;
0797     else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
0798         *key_len = 256;
0799     else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
0800         *key_len = 256;
0801     else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
0802         *key_len = 288;
0803     else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
0804         *key_len = 384;
0805     else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
0806         *key_len = 448;
0807     else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
0808         *key_len = 512;
0809     else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
0810         *key_len = 160;
0811     else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
0812         *key_len = 160;
0813     else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
0814         *key_len = 152;
0815     else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
0816         *key_len = 224;
0817     else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
0818         *key_len = 224;
0819     else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
0820         *key_len = 216;
0821     else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
0822         *key_len = 288;
0823     else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
0824         *key_len = 288;
0825     else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
0826         *key_len = 280;
0827     else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
0828         *key_len = 0;
0829 
0830     if (*key_len > buf_len) {
0831         printk("Can't pack a key - too big for buffer");
0832         return -1;
0833     }
0834 
0835     randomize_buffer(buf, *key_len);
0836 
0837     return 0;
0838 }
0839 
0840 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
0841         struct xfrm_desc *desc)
0842 {
0843     struct {
0844         union {
0845             struct xfrm_algo    alg;
0846             struct xfrm_algo_aead   aead;
0847             struct xfrm_algo_auth   auth;
0848         } u;
0849         char buf[XFRM_ALGO_KEY_BUF_SIZE];
0850     } alg = {};
0851     size_t alen, elen, clen, aelen;
0852     unsigned short type;
0853 
0854     alen = strlen(desc->a_algo);
0855     elen = strlen(desc->e_algo);
0856     clen = strlen(desc->c_algo);
0857     aelen = strlen(desc->ae_algo);
0858 
0859     /* Verify desc */
0860     switch (desc->proto) {
0861     case IPPROTO_AH:
0862         if (!alen || elen || clen || aelen) {
0863             printk("BUG: buggy ah desc");
0864             return -1;
0865         }
0866         strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
0867         if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
0868                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
0869             return -1;
0870         type = XFRMA_ALG_AUTH;
0871         break;
0872     case IPPROTO_COMP:
0873         if (!clen || elen || alen || aelen) {
0874             printk("BUG: buggy comp desc");
0875             return -1;
0876         }
0877         strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
0878         if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
0879                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
0880             return -1;
0881         type = XFRMA_ALG_COMP;
0882         break;
0883     case IPPROTO_ESP:
0884         if (!((alen && elen) ^ aelen) || clen) {
0885             printk("BUG: buggy esp desc");
0886             return -1;
0887         }
0888         if (aelen) {
0889             alg.u.aead.alg_icv_len = desc->icv_len;
0890             strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
0891             if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
0892                         sizeof(alg.buf), &alg.u.aead.alg_key_len))
0893                 return -1;
0894             type = XFRMA_ALG_AEAD;
0895         } else {
0896 
0897             strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
0898             type = XFRMA_ALG_CRYPT;
0899             if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
0900                         sizeof(alg.buf), &alg.u.alg.alg_key_len))
0901                 return -1;
0902             if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
0903                 return -1;
0904 
0905             strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
0906             type = XFRMA_ALG_AUTH;
0907             if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
0908                         sizeof(alg.buf), &alg.u.alg.alg_key_len))
0909                 return -1;
0910         }
0911         break;
0912     default:
0913         printk("BUG: unknown proto in desc");
0914         return -1;
0915     }
0916 
0917     if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
0918         return -1;
0919 
0920     return 0;
0921 }
0922 
/* SPI derived from the (classful) host part of @src, in network byte order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
0927 
0928 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
0929         struct in_addr src, struct in_addr dst,
0930         struct xfrm_desc *desc)
0931 {
0932     struct {
0933         struct nlmsghdr     nh;
0934         struct xfrm_usersa_info info;
0935         char            attrbuf[MAX_PAYLOAD];
0936     } req;
0937 
0938     memset(&req, 0, sizeof(req));
0939     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.info));
0940     req.nh.nlmsg_type   = XFRM_MSG_NEWSA;
0941     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
0942     req.nh.nlmsg_seq    = seq;
0943 
0944     /* Fill selector. */
0945     memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
0946     memcpy(&req.info.sel.saddr, &src, sizeof(src));
0947     req.info.sel.family     = AF_INET;
0948     req.info.sel.prefixlen_d    = PREFIX_LEN;
0949     req.info.sel.prefixlen_s    = PREFIX_LEN;
0950 
0951     /* Fill id */
0952     memcpy(&req.info.id.daddr, &dst, sizeof(dst));
0953     /* Note: zero-spi cannot be deleted */
0954     req.info.id.spi = spi;
0955     req.info.id.proto   = desc->proto;
0956 
0957     memcpy(&req.info.saddr, &src, sizeof(src));
0958 
0959     /* Fill lifteme_cfg */
0960     req.info.lft.soft_byte_limit    = XFRM_INF;
0961     req.info.lft.hard_byte_limit    = XFRM_INF;
0962     req.info.lft.soft_packet_limit  = XFRM_INF;
0963     req.info.lft.hard_packet_limit  = XFRM_INF;
0964 
0965     req.info.family     = AF_INET;
0966     req.info.mode       = XFRM_MODE_TUNNEL;
0967 
0968     if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
0969         return -1;
0970 
0971     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
0972         pr_err("send()");
0973         return -1;
0974     }
0975 
0976     return netlink_check_answer(xfrm_sock);
0977 }
0978 
0979 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
0980         struct in_addr src, struct in_addr dst,
0981         struct xfrm_desc *desc)
0982 {
0983     if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
0984         return false;
0985 
0986     if (memcmp(&info->sel.saddr, &src, sizeof(src)))
0987         return false;
0988 
0989     if (info->sel.family != AF_INET                 ||
0990             info->sel.prefixlen_d != PREFIX_LEN     ||
0991             info->sel.prefixlen_s != PREFIX_LEN)
0992         return false;
0993 
0994     if (info->id.spi != spi || info->id.proto != desc->proto)
0995         return false;
0996 
0997     if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
0998         return false;
0999 
1000     if (memcmp(&info->saddr, &src, sizeof(src)))
1001         return false;
1002 
1003     if (info->lft.soft_byte_limit != XFRM_INF           ||
1004             info->lft.hard_byte_limit != XFRM_INF       ||
1005             info->lft.soft_packet_limit != XFRM_INF     ||
1006             info->lft.hard_packet_limit != XFRM_INF)
1007         return false;
1008 
1009     if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1010         return false;
1011 
1012     /* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1013 
1014     return true;
1015 }
1016 
1017 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1018         struct in_addr src, struct in_addr dst,
1019         struct xfrm_desc *desc)
1020 {
1021     struct {
1022         struct nlmsghdr     nh;
1023         char            attrbuf[MAX_PAYLOAD];
1024     } req;
1025     struct {
1026         struct nlmsghdr     nh;
1027         union {
1028             struct xfrm_usersa_info info;
1029             int error;
1030         };
1031         char            attrbuf[MAX_PAYLOAD];
1032     } answer;
1033     struct xfrm_address_filter filter = {};
1034     bool found = false;
1035 
1036 
1037     memset(&req, 0, sizeof(req));
1038     req.nh.nlmsg_len    = NLMSG_LENGTH(0);
1039     req.nh.nlmsg_type   = XFRM_MSG_GETSA;
1040     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_DUMP;
1041     req.nh.nlmsg_seq    = seq;
1042 
1043     /*
1044      * Add dump filter by source address as there may be other tunnels
1045      * in this netns (if tests run in parallel).
1046      */
1047     filter.family = AF_INET;
1048     filter.splen = 0x1f;    /* 0xffffffff mask see addr_match() */
1049     memcpy(&filter.saddr, &src, sizeof(src));
1050     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1051                 &filter, sizeof(filter)))
1052         return -1;
1053 
1054     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1055         pr_err("send()");
1056         return -1;
1057     }
1058 
1059     while (1) {
1060         if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1061             pr_err("recv()");
1062             return -1;
1063         }
1064         if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1065             printk("NLMSG_ERROR: %d: %s",
1066                 answer.error, strerror(-answer.error));
1067             return -1;
1068         } else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1069             if (found)
1070                 return 0;
1071             printk("didn't find allocated xfrm state in dump");
1072             return -1;
1073         } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1074             if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1075                 found = true;
1076         }
1077     }
1078 }
1079 
/*
 * xfrm_set - install both directions of the tunnel SA pair and verify them.
 *
 * Adds a src -> dst and a dst -> src SA, both with the SPI derived from
 * @src, then checks each one via an SA dump.  @tunsrc/@tundst are not
 * used here.  Each call consumes sequence numbers from *@seq.
 * Returns 0 on success, -1 on the first failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	int check_err;

	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both run even if the first fails. */
	check_err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src),
				     src, dst, desc);
	check_err |= xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src),
				      dst, src, desc);
	if (!check_err)
		return 0;

	printk("Failed to check xfrm state");
	return -1;
}
1109 
1110 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1111         struct in_addr src, struct in_addr dst, uint8_t dir,
1112         struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1113 {
1114     struct {
1115         struct nlmsghdr         nh;
1116         struct xfrm_userpolicy_info info;
1117         char                attrbuf[MAX_PAYLOAD];
1118     } req;
1119     struct xfrm_user_tmpl tmpl;
1120 
1121     memset(&req, 0, sizeof(req));
1122     memset(&tmpl, 0, sizeof(tmpl));
1123     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.info));
1124     req.nh.nlmsg_type   = XFRM_MSG_NEWPOLICY;
1125     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1126     req.nh.nlmsg_seq    = seq;
1127 
1128     /* Fill selector. */
1129     memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1130     memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1131     req.info.sel.family     = AF_INET;
1132     req.info.sel.prefixlen_d    = PREFIX_LEN;
1133     req.info.sel.prefixlen_s    = PREFIX_LEN;
1134 
1135     /* Fill lifteme_cfg */
1136     req.info.lft.soft_byte_limit    = XFRM_INF;
1137     req.info.lft.hard_byte_limit    = XFRM_INF;
1138     req.info.lft.soft_packet_limit  = XFRM_INF;
1139     req.info.lft.hard_packet_limit  = XFRM_INF;
1140 
1141     req.info.dir = dir;
1142 
1143     /* Fill tmpl */
1144     memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1145     /* Note: zero-spi cannot be deleted */
1146     tmpl.id.spi = spi;
1147     tmpl.id.proto   = proto;
1148     tmpl.family = AF_INET;
1149     memcpy(&tmpl.saddr, &src, sizeof(src));
1150     tmpl.mode   = XFRM_MODE_TUNNEL;
1151     tmpl.aalgos = (~(uint32_t)0);
1152     tmpl.ealgos = (~(uint32_t)0);
1153     tmpl.calgos = (~(uint32_t)0);
1154 
1155     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1156         return -1;
1157 
1158     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1159         pr_err("send()");
1160         return -1;
1161     }
1162 
1163     return netlink_check_answer(xfrm_sock);
1164 }
1165 
/*
 * xfrm_prepare - install the OUT (src -> dst) and IN (dst -> src) policies.
 * Both use the SPI derived from @src.  The IN policy is not attempted if
 * the OUT one fails.  Returns 0 on success, -1 on failure.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	bool failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto) ||
		 xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
				 XFRM_POLICY_IN, tunsrc, tundst, proto);
	if (!failed)
		return 0;

	printk("Failed to add xfrm policy");
	return -1;
}
1184 
1185 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1186         struct in_addr src, struct in_addr dst, uint8_t dir,
1187         struct in_addr tunsrc, struct in_addr tundst)
1188 {
1189     struct {
1190         struct nlmsghdr         nh;
1191         struct xfrm_userpolicy_id   id;
1192         char                attrbuf[MAX_PAYLOAD];
1193     } req;
1194 
1195     memset(&req, 0, sizeof(req));
1196     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.id));
1197     req.nh.nlmsg_type   = XFRM_MSG_DELPOLICY;
1198     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1199     req.nh.nlmsg_seq    = seq;
1200 
1201     /* Fill id */
1202     memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1203     memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1204     req.id.sel.family       = AF_INET;
1205     req.id.sel.prefixlen_d      = PREFIX_LEN;
1206     req.id.sel.prefixlen_s      = PREFIX_LEN;
1207     req.id.dir = dir;
1208 
1209     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1210         pr_err("send()");
1211         return -1;
1212     }
1213 
1214     return netlink_check_answer(xfrm_sock);
1215 }
1216 
/*
 * xfrm_cleanup - remove the policy pair that xfrm_prepare() installs:
 * the OUT policy for src -> dst and the IN policy for dst -> src.
 *
 * The IN deletion is not attempted if the OUT one fails.
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	return 0;
}
1235 
1236 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1237         struct in_addr src, struct in_addr dst, uint8_t proto)
1238 {
1239     struct {
1240         struct nlmsghdr     nh;
1241         struct xfrm_usersa_id   id;
1242         char            attrbuf[MAX_PAYLOAD];
1243     } req;
1244     xfrm_address_t saddr = {};
1245 
1246     memset(&req, 0, sizeof(req));
1247     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.id));
1248     req.nh.nlmsg_type   = XFRM_MSG_DELSA;
1249     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1250     req.nh.nlmsg_seq    = seq;
1251 
1252     memcpy(&req.id.daddr, &dst, sizeof(dst));
1253     req.id.family       = AF_INET;
1254     req.id.proto        = proto;
1255     /* Note: zero-spi cannot be deleted */
1256     req.id.spi = spi;
1257 
1258     memcpy(&saddr, &src, sizeof(src));
1259     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1260         return -1;
1261 
1262     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1263         pr_err("send()");
1264         return -1;
1265     }
1266 
1267     return netlink_check_answer(xfrm_sock);
1268 }
1269 
/*
 * xfrm_delete - remove the src -> dst and dst -> src SAs that xfrm_set()
 * installs (both keyed by the SPI derived from @src).
 * The second deletion is skipped if the first fails.
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	bool failed;

	failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto) ||
		 xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto);
	if (!failed)
		return 0;

	printk("Failed to remove xfrm state");
	return -1;
}
1286 
1287 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1288         uint32_t spi, uint8_t proto)
1289 {
1290     struct {
1291         struct nlmsghdr         nh;
1292         struct xfrm_userspi_info    spi;
1293     } req;
1294     struct {
1295         struct nlmsghdr         nh;
1296         union {
1297             struct xfrm_usersa_info info;
1298             int error;
1299         };
1300     } answer;
1301 
1302     memset(&req, 0, sizeof(req));
1303     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.spi));
1304     req.nh.nlmsg_type   = XFRM_MSG_ALLOCSPI;
1305     req.nh.nlmsg_flags  = NLM_F_REQUEST;
1306     req.nh.nlmsg_seq    = (*seq)++;
1307 
1308     req.spi.info.family = AF_INET;
1309     req.spi.min     = spi;
1310     req.spi.max     = spi;
1311     req.spi.info.id.proto   = proto;
1312 
1313     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1314         pr_err("send()");
1315         return KSFT_FAIL;
1316     }
1317 
1318     if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1319         pr_err("recv()");
1320         return KSFT_FAIL;
1321     } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1322         uint32_t new_spi = htonl(answer.info.id.spi);
1323 
1324         if (new_spi != spi) {
1325             printk("allocated spi is different from requested: %#x != %#x",
1326                     new_spi, spi);
1327             return KSFT_FAIL;
1328         }
1329         return KSFT_PASS;
1330     } else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1331         printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1332         return KSFT_FAIL;
1333     }
1334 
1335     printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1336     return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1337 }
1338 
/*
 * netlink_sock_bind - open a netlink socket via netlink_sock() and bind it.
 * @sock:	out: the socket fd; reset to -1 if binding/verification fails
 * @seq:	out: initial sequence number, as set by netlink_sock()
 * @proto:	netlink protocol (e.g. NETLINK_XFRM)
 * @groups:	value stored in sockaddr_nl.nl_groups for the bind
 *
 * After bind() the bound address is read back with getsockname() and its
 * length and family are sanity-checked.
 * Returns 0 on success, -1 on failure (the socket is closed on failures
 * after it was opened).
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	/* Don't leave a stale, already-closed fd in the caller's variable. */
	*sock = -1;
	return -1;
}
1377 
1378 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1379 {
1380     struct {
1381         struct nlmsghdr nh;
1382         union {
1383             struct xfrm_user_acquire acq;
1384             int error;
1385         };
1386         char attrbuf[MAX_PAYLOAD];
1387     } req;
1388     struct xfrm_user_tmpl xfrm_tmpl = {};
1389     int xfrm_listen = -1, ret = KSFT_FAIL;
1390     uint32_t seq_listen;
1391 
1392     if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1393         return KSFT_FAIL;
1394 
1395     memset(&req, 0, sizeof(req));
1396     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.acq));
1397     req.nh.nlmsg_type   = XFRM_MSG_ACQUIRE;
1398     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1399     req.nh.nlmsg_seq    = (*seq)++;
1400 
1401     req.acq.policy.sel.family   = AF_INET;
1402     req.acq.aalgos  = 0xfeed;
1403     req.acq.ealgos  = 0xbaad;
1404     req.acq.calgos  = 0xbabe;
1405 
1406     xfrm_tmpl.family = AF_INET;
1407     xfrm_tmpl.id.proto = IPPROTO_ESP;
1408     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1409         goto out_close;
1410 
1411     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1412         pr_err("send()");
1413         goto out_close;
1414     }
1415 
1416     if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1417         pr_err("recv()");
1418         goto out_close;
1419     } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1420         printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1421         goto out_close;
1422     }
1423 
1424     if (req.error) {
1425         printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1426         ret = req.error;
1427         goto out_close;
1428     }
1429 
1430     if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1431         pr_err("recv()");
1432         goto out_close;
1433     }
1434 
1435     if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1436             || req.acq.calgos != 0xbabe) {
1437         printk("xfrm_user_acquire has changed  %x %x %x",
1438                 req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1439         goto out_close;
1440     }
1441 
1442     ret = KSFT_PASS;
1443 out_close:
1444     close(xfrm_listen);
1445     return ret;
1446 }
1447 
1448 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1449         unsigned int nr, struct xfrm_desc *desc)
1450 {
1451     struct {
1452         struct nlmsghdr nh;
1453         union {
1454             struct xfrm_user_expire expire;
1455             int error;
1456         };
1457     } req;
1458     struct in_addr src, dst;
1459     int xfrm_listen = -1, ret = KSFT_FAIL;
1460     uint32_t seq_listen;
1461 
1462     src = inet_makeaddr(INADDR_B, child_ip(nr));
1463     dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1464 
1465     if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1466         printk("Failed to add xfrm state");
1467         return KSFT_FAIL;
1468     }
1469 
1470     if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1471         return KSFT_FAIL;
1472 
1473     memset(&req, 0, sizeof(req));
1474     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.expire));
1475     req.nh.nlmsg_type   = XFRM_MSG_EXPIRE;
1476     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1477     req.nh.nlmsg_seq    = (*seq)++;
1478 
1479     memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1480     req.expire.state.id.spi     = gen_spi(src);
1481     req.expire.state.id.proto   = desc->proto;
1482     req.expire.state.family     = AF_INET;
1483     req.expire.hard         = 0xff;
1484 
1485     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1486         pr_err("send()");
1487         goto out_close;
1488     }
1489 
1490     if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1491         pr_err("recv()");
1492         goto out_close;
1493     } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1494         printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1495         goto out_close;
1496     }
1497 
1498     if (req.error) {
1499         printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1500         ret = req.error;
1501         goto out_close;
1502     }
1503 
1504     if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1505         pr_err("recv()");
1506         goto out_close;
1507     }
1508 
1509     if (req.expire.hard != 0x1) {
1510         printk("expire.hard is not set: %x", req.expire.hard);
1511         goto out_close;
1512     }
1513 
1514     ret = KSFT_PASS;
1515 out_close:
1516     close(xfrm_listen);
1517     return ret;
1518 }
1519 
1520 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1521         unsigned int nr, struct xfrm_desc *desc)
1522 {
1523     struct {
1524         struct nlmsghdr nh;
1525         union {
1526             struct xfrm_user_polexpire expire;
1527             int error;
1528         };
1529     } req;
1530     struct in_addr src, dst, tunsrc, tundst;
1531     int xfrm_listen = -1, ret = KSFT_FAIL;
1532     uint32_t seq_listen;
1533 
1534     src = inet_makeaddr(INADDR_B, child_ip(nr));
1535     dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1536     tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1537     tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1538 
1539     if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1540                 XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1541         printk("Failed to add xfrm policy");
1542         return KSFT_FAIL;
1543     }
1544 
1545     if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1546         return KSFT_FAIL;
1547 
1548     memset(&req, 0, sizeof(req));
1549     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.expire));
1550     req.nh.nlmsg_type   = XFRM_MSG_POLEXPIRE;
1551     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1552     req.nh.nlmsg_seq    = (*seq)++;
1553 
1554     /* Fill selector. */
1555     memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1556     memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1557     req.expire.pol.sel.family   = AF_INET;
1558     req.expire.pol.sel.prefixlen_d  = PREFIX_LEN;
1559     req.expire.pol.sel.prefixlen_s  = PREFIX_LEN;
1560     req.expire.pol.dir      = XFRM_POLICY_OUT;
1561     req.expire.hard         = 0xff;
1562 
1563     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1564         pr_err("send()");
1565         goto out_close;
1566     }
1567 
1568     if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1569         pr_err("recv()");
1570         goto out_close;
1571     } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1572         printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1573         goto out_close;
1574     }
1575 
1576     if (req.error) {
1577         printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1578         ret = req.error;
1579         goto out_close;
1580     }
1581 
1582     if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1583         pr_err("recv()");
1584         goto out_close;
1585     }
1586 
1587     if (req.expire.hard != 0x1) {
1588         printk("expire.hard is not set: %x", req.expire.hard);
1589         goto out_close;
1590     }
1591 
1592     ret = KSFT_PASS;
1593 out_close:
1594     close(xfrm_listen);
1595     return ret;
1596 }
1597 
1598 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1599         unsigned thresh4_l, unsigned thresh4_r,
1600         unsigned thresh6_l, unsigned thresh6_r,
1601         bool add_bad_attr)
1602 
1603 {
1604     struct {
1605         struct nlmsghdr     nh;
1606         union {
1607             uint32_t    unused;
1608             int     error;
1609         };
1610         char            attrbuf[MAX_PAYLOAD];
1611     } req;
1612     struct xfrmu_spdhthresh thresh;
1613 
1614     memset(&req, 0, sizeof(req));
1615     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.unused));
1616     req.nh.nlmsg_type   = XFRM_MSG_NEWSPDINFO;
1617     req.nh.nlmsg_flags  = NLM_F_REQUEST | NLM_F_ACK;
1618     req.nh.nlmsg_seq    = (*seq)++;
1619 
1620     thresh.lbits = thresh4_l;
1621     thresh.rbits = thresh4_r;
1622     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1623         return -1;
1624 
1625     thresh.lbits = thresh6_l;
1626     thresh.rbits = thresh6_r;
1627     if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1628         return -1;
1629 
1630     if (add_bad_attr) {
1631         BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1632         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1633             pr_err("adding attribute failed: no space");
1634             return -1;
1635         }
1636     }
1637 
1638     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1639         pr_err("send()");
1640         return -1;
1641     }
1642 
1643     if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1644         pr_err("recv()");
1645         return -1;
1646     } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1647         printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1648         return -1;
1649     }
1650 
1651     if (req.error) {
1652         printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1653         return -1;
1654     }
1655 
1656     return 0;
1657 }
1658 
1659 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1660 {
1661     struct {
1662         struct nlmsghdr         nh;
1663         union {
1664             uint32_t    unused;
1665             int     error;
1666         };
1667         char            attrbuf[MAX_PAYLOAD];
1668     } req;
1669 
1670     if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1671         pr_err("Can't set SPD HTHRESH");
1672         return KSFT_FAIL;
1673     }
1674 
1675     memset(&req, 0, sizeof(req));
1676 
1677     req.nh.nlmsg_len    = NLMSG_LENGTH(sizeof(req.unused));
1678     req.nh.nlmsg_type   = XFRM_MSG_GETSPDINFO;
1679     req.nh.nlmsg_flags  = NLM_F_REQUEST;
1680     req.nh.nlmsg_seq    = (*seq)++;
1681     if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1682         pr_err("send()");
1683         return KSFT_FAIL;
1684     }
1685 
1686     if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1687         pr_err("recv()");
1688         return KSFT_FAIL;
1689     } else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1690         size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1691         struct rtattr *attr = (void *)req.attrbuf;
1692         int got_thresh = 0;
1693 
1694         for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1695             if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1696                 struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1697 
1698                 got_thresh++;
1699                 if (t->lbits != 32 || t->rbits != 31) {
1700                     pr_err("thresh differ: %u, %u",
1701                             t->lbits, t->rbits);
1702                     return KSFT_FAIL;
1703                 }
1704             }
1705             if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1706                 struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1707 
1708                 got_thresh++;
1709                 if (t->lbits != 120 || t->rbits != 16) {
1710                     pr_err("thresh differ: %u, %u",
1711                             t->lbits, t->rbits);
1712                     return KSFT_FAIL;
1713                 }
1714             }
1715         }
1716         if (got_thresh != 2) {
1717             pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1718             return KSFT_FAIL;
1719         }
1720     } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1721         printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1722         return KSFT_FAIL;
1723     } else {
1724         printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1725         return -1;
1726     }
1727 
1728     /* Restore the default */
1729     if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1730         pr_err("Can't restore SPD HTHRESH");
1731         return KSFT_FAIL;
1732     }
1733 
1734     /*
1735      * At this moment xfrm uses nlmsg_parse_deprecated(), which
1736      * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1737      * (type > maxtype). nla_parse_depricated_strict() would enforce
1738      * it. Or even stricter nla_parse().
1739      * Right now it's not expected to fail, but to be ignored.
1740      */
1741     if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1742         return KSFT_PASS;
1743 
1744     return KSFT_PASS;
1745 }
1746 
1747 static int child_serv(int xfrm_sock, uint32_t *seq,
1748         unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1749 {
1750     struct in_addr src, dst, tunsrc, tundst;
1751     struct test_desc msg;
1752     int ret = KSFT_FAIL;
1753 
1754     src = inet_makeaddr(INADDR_B, child_ip(nr));
1755     dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1756     tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1757     tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1758 
1759     /* UDP pinging without xfrm */
1760     if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1761         printk("ping failed before setting xfrm");
1762         return KSFT_FAIL;
1763     }
1764 
1765     memset(&msg, 0, sizeof(msg));
1766     msg.type = MSG_XFRM_PREPARE;
1767     memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1768     write_msg(cmd_fd, &msg, 1);
1769 
1770     if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1771         printk("failed to prepare xfrm");
1772         goto cleanup;
1773     }
1774 
1775     memset(&msg, 0, sizeof(msg));
1776     msg.type = MSG_XFRM_ADD;
1777     memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1778     write_msg(cmd_fd, &msg, 1);
1779     if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1780         printk("failed to set xfrm");
1781         goto delete;
1782     }
1783 
1784     /* UDP pinging with xfrm tunnel */
1785     if (do_ping(cmd_fd, buf, page_size, tunsrc,
1786                 true, 0, 0, udp_ping_send)) {
1787         printk("ping failed for xfrm");
1788         goto delete;
1789     }
1790 
1791     ret = KSFT_PASS;
1792 delete:
1793     /* xfrm delete */
1794     memset(&msg, 0, sizeof(msg));
1795     msg.type = MSG_XFRM_DEL;
1796     memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1797     write_msg(cmd_fd, &msg, 1);
1798 
1799     if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1800         printk("failed ping to remove xfrm");
1801         ret = KSFT_FAIL;
1802     }
1803 
1804 cleanup:
1805     memset(&msg, 0, sizeof(msg));
1806     msg.type = MSG_XFRM_CLEANUP;
1807     memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1808     write_msg(cmd_fd, &msg, 1);
1809     if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1810         printk("failed ping to cleanup xfrm");
1811         ret = KSFT_FAIL;
1812     }
1813     return ret;
1814 }
1815 
1816 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1817 {
1818     struct xfrm_desc desc;
1819     struct test_desc msg;
1820     int xfrm_sock = -1;
1821     uint32_t seq;
1822 
1823     if (switch_ns(nsfd_childa))
1824         exit(KSFT_FAIL);
1825 
1826     if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1827         printk("Failed to open xfrm netlink socket");
1828         exit(KSFT_FAIL);
1829     }
1830 
1831     /* Check that seq sock is ready, just for sure. */
1832     memset(&msg, 0, sizeof(msg));
1833     msg.type = MSG_ACK;
1834     write_msg(cmd_fd, &msg, 1);
1835     read_msg(cmd_fd, &msg, 1);
1836     if (msg.type != MSG_ACK) {
1837         printk("Ack failed");
1838         exit(KSFT_FAIL);
1839     }
1840 
1841     for (;;) {
1842         ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1843         int ret;
1844 
1845         if (received == 0) /* EOF */
1846             break;
1847 
1848         if (received != sizeof(desc)) {
1849             pr_err("read() returned %zd", received);
1850             exit(KSFT_FAIL);
1851         }
1852 
1853         switch (desc.type) {
1854         case CREATE_TUNNEL:
1855             ret = child_serv(xfrm_sock, &seq, nr,
1856                      cmd_fd, buf, &desc);
1857             break;
1858         case ALLOCATE_SPI:
1859             ret = xfrm_state_allocspi(xfrm_sock, &seq,
1860                           -1, desc.proto);
1861             break;
1862         case MONITOR_ACQUIRE:
1863             ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1864             break;
1865         case EXPIRE_STATE:
1866             ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1867             break;
1868         case EXPIRE_POLICY:
1869             ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1870             break;
1871         case SPDINFO_ATTRS:
1872             ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1873             break;
1874         default:
1875             printk("Unknown desc type %d", desc.type);
1876             exit(KSFT_FAIL);
1877         }
1878         write_test_result(ret, &desc);
1879     }
1880 
1881     close(xfrm_sock);
1882 
1883     msg.type = MSG_EXIT;
1884     write_msg(cmd_fd, &msg, 1);
1885     exit(KSFT_PASS);
1886 }
1887 
1888 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1889         struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1890 {
1891     struct in_addr src, dst, tunsrc, tundst;
1892     bool tun_reply;
1893     struct xfrm_desc *desc = &msg->body.xfrm_desc;
1894 
1895     src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1896     dst = inet_makeaddr(INADDR_B, child_ip(nr));
1897     tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1898     tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1899 
1900     switch (msg->type) {
1901     case MSG_EXIT:
1902         exit(KSFT_PASS);
1903     case MSG_ACK:
1904         write_msg(cmd_fd, msg, 1);
1905         break;
1906     case MSG_PING:
1907         tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1908         /* UDP pinging without xfrm */
1909         if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1910                 false, msg->body.ping.port,
1911                 msg->body.ping.reply_ip, udp_ping_reply)) {
1912             printk("ping failed before setting xfrm");
1913         }
1914         break;
1915     case MSG_XFRM_PREPARE:
1916         if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1917                     desc->proto)) {
1918             xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1919             printk("failed to prepare xfrm");
1920         }
1921         break;
1922     case MSG_XFRM_ADD:
1923         if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1924             xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1925             printk("failed to set xfrm");
1926         }
1927         break;
1928     case MSG_XFRM_DEL:
1929         if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1930                     desc->proto)) {
1931             xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1932             printk("failed to remove xfrm");
1933         }
1934         break;
1935     case MSG_XFRM_CLEANUP:
1936         if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1937             printk("failed to cleanup xfrm");
1938         }
1939         break;
1940     default:
1941         printk("got unknown msg type %d", msg->type);
1942     }
1943 }
1944 
1945 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1946 {
1947     struct test_desc msg;
1948     int xfrm_sock = -1;
1949     uint32_t seq;
1950 
1951     if (switch_ns(nsfd_childb))
1952         exit(KSFT_FAIL);
1953 
1954     if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1955         printk("Failed to open xfrm netlink socket");
1956         exit(KSFT_FAIL);
1957     }
1958 
1959     do {
1960         read_msg(cmd_fd, &msg, 1);
1961         grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1962     } while (1);
1963 
1964     close(xfrm_sock);
1965     exit(KSFT_FAIL);
1966 }
1967 
1968 static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1969 {
1970     int cmd_sock[2];
1971     void *data_map;
1972     pid_t child;
1973 
1974     if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1975         return -1;
1976 
1977     if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1978         return -1;
1979 
1980     child = fork();
1981     if (child < 0) {
1982         pr_err("fork()");
1983         return -1;
1984     } else if (child) {
1985         /* in parent - selftest */
1986         return switch_ns(nsfd_parent);
1987     }
1988 
1989     if (close(test_desc_fd[1])) {
1990         pr_err("close()");
1991         return -1;
1992     }
1993 
1994     /* child */
1995     data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1996             MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1997     if (data_map == MAP_FAILED) {
1998         pr_err("mmap()");
1999         return -1;
2000     }
2001 
2002     randomize_buffer(data_map, page_size);
2003 
2004     if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
2005         pr_err("socketpair()");
2006         return -1;
2007     }
2008 
2009     child = fork();
2010     if (child < 0) {
2011         pr_err("fork()");
2012         return -1;
2013     } else if (child) {
2014         if (close(cmd_sock[0])) {
2015             pr_err("close()");
2016             return -1;
2017         }
2018         return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2019     }
2020     if (close(cmd_sock[1])) {
2021         pr_err("close()");
2022         return -1;
2023     }
2024     return grand_child_f(nr, cmd_sock[0], data_map);
2025 }
2026 
2027 static void exit_usage(char **argv)
2028 {
2029     printk("Usage: %s [nr_process]", argv[0]);
2030     exit(KSFT_FAIL);
2031 }
2032 
2033 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2034 {
2035     ssize_t ret;
2036 
2037     ret = write(test_desc_fd, desc, sizeof(*desc));
2038 
2039     if (ret == sizeof(*desc))
2040         return 0;
2041 
2042     pr_err("Writing test's desc failed %ld", ret);
2043 
2044     return -1;
2045 }
2046 
2047 static int write_desc(int proto, int test_desc_fd,
2048         char *a, char *e, char *c, char *ae)
2049 {
2050     struct xfrm_desc desc = {};
2051 
2052     desc.type = CREATE_TUNNEL;
2053     desc.proto = proto;
2054 
2055     if (a)
2056         strncpy(desc.a_algo, a, ALGO_LEN - 1);
2057     if (e)
2058         strncpy(desc.e_algo, e, ALGO_LEN - 1);
2059     if (c)
2060         strncpy(desc.c_algo, c, ALGO_LEN - 1);
2061     if (ae)
2062         strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2063 
2064     return __write_desc(test_desc_fd, &desc);
2065 }
2066 
2067 int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
2068 char *ah_list[] = {
2069     "digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
2070     "hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
2071     "xcbc(aes)", "cmac(aes)"
2072 };
2073 char *comp_list[] = {
2074     "deflate",
2075 #if 0
2076     /* No compression backend realization */
2077     "lzs", "lzjh"
2078 #endif
2079 };
2080 char *e_list[] = {
2081     "ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
2082     "cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
2083     "cbc(twofish)", "rfc3686(ctr(aes))"
2084 };
2085 char *ae_list[] = {
2086 #if 0
2087     /* not implemented */
2088     "rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
2089     "rfc7539esp(chacha20,poly1305)"
2090 #endif
2091 };
2092 
2093 const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2094                 + (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2095                 + ARRAY_SIZE(ae_list);
2096 
2097 static int write_proto_plan(int fd, int proto)
2098 {
2099     unsigned int i;
2100 
2101     switch (proto) {
2102     case IPPROTO_AH:
2103         for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2104             if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2105                 return -1;
2106         }
2107         break;
2108     case IPPROTO_COMP:
2109         for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2110             if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2111                 return -1;
2112         }
2113         break;
2114     case IPPROTO_ESP:
2115         for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2116             int j;
2117 
2118             for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2119                 if (write_desc(proto, fd, ah_list[i],
2120                             e_list[j], 0, 0))
2121                     return -1;
2122             }
2123         }
2124         for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2125             if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2126                 return -1;
2127         }
2128         break;
2129     default:
2130         printk("BUG: Specified unknown proto %d", proto);
2131         return -1;
2132     }
2133 
2134     return 0;
2135 }
2136 
2137 /*
2138  * Some structures in xfrm uapi header differ in size between
2139  * 64-bit and 32-bit ABI:
2140  *
2141  *             32-bit UABI               |            64-bit UABI
2142  *  -------------------------------------|-------------------------------------
2143  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2144  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2145  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2146  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2147  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2148  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2149  *
2150  * Check the affected by the UABI difference structures.
2151  * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
2152  * which needs to be correctly copied, but not translated.
2153  */
2154 const unsigned int compat_plan = 5;
2155 static int write_compat_struct_tests(int test_desc_fd)
2156 {
2157     struct xfrm_desc desc = {};
2158 
2159     desc.type = ALLOCATE_SPI;
2160     desc.proto = IPPROTO_AH;
2161     strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2162 
2163     if (__write_desc(test_desc_fd, &desc))
2164         return -1;
2165 
2166     desc.type = MONITOR_ACQUIRE;
2167     if (__write_desc(test_desc_fd, &desc))
2168         return -1;
2169 
2170     desc.type = EXPIRE_STATE;
2171     if (__write_desc(test_desc_fd, &desc))
2172         return -1;
2173 
2174     desc.type = EXPIRE_POLICY;
2175     if (__write_desc(test_desc_fd, &desc))
2176         return -1;
2177 
2178     desc.type = SPDINFO_ATTRS;
2179     if (__write_desc(test_desc_fd, &desc))
2180         return -1;
2181 
2182     return 0;
2183 }
2184 
2185 static int write_test_plan(int test_desc_fd)
2186 {
2187     unsigned int i;
2188     pid_t child;
2189 
2190     child = fork();
2191     if (child < 0) {
2192         pr_err("fork()");
2193         return -1;
2194     }
2195     if (child) {
2196         if (close(test_desc_fd))
2197             printk("close(): %m");
2198         return 0;
2199     }
2200 
2201     if (write_compat_struct_tests(test_desc_fd))
2202         exit(KSFT_FAIL);
2203 
2204     for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2205         if (write_proto_plan(test_desc_fd, proto_list[i]))
2206             exit(KSFT_FAIL);
2207     }
2208 
2209     exit(KSFT_PASS);
2210 }
2211 
2212 static int children_cleanup(void)
2213 {
2214     unsigned ret = KSFT_PASS;
2215 
2216     while (1) {
2217         int status;
2218         pid_t p = wait(&status);
2219 
2220         if ((p < 0) && errno == ECHILD)
2221             break;
2222 
2223         if (p < 0) {
2224             pr_err("wait()");
2225             return KSFT_FAIL;
2226         }
2227 
2228         if (!WIFEXITED(status)) {
2229             ret = KSFT_FAIL;
2230             continue;
2231         }
2232 
2233         if (WEXITSTATUS(status) == KSFT_FAIL)
2234             ret = KSFT_FAIL;
2235     }
2236 
2237     return ret;
2238 }
2239 
2240 typedef void (*print_res)(const char *, ...);
2241 
2242 static int check_results(void)
2243 {
2244     struct test_result tr = {};
2245     struct xfrm_desc *d = &tr.desc;
2246     int ret = KSFT_PASS;
2247 
2248     while (1) {
2249         ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2250         print_res result;
2251 
2252         if (received == 0) /* EOF */
2253             break;
2254 
2255         if (received != sizeof(tr)) {
2256             pr_err("read() returned %zd", received);
2257             return KSFT_FAIL;
2258         }
2259 
2260         switch (tr.res) {
2261         case KSFT_PASS:
2262             result = ksft_test_result_pass;
2263             break;
2264         case KSFT_FAIL:
2265         default:
2266             result = ksft_test_result_fail;
2267             ret = KSFT_FAIL;
2268         }
2269 
2270         result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2271                desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2272                d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2273     }
2274 
2275     return ret;
2276 }
2277 
2278 int main(int argc, char **argv)
2279 {
2280     unsigned int nr_process = 1;
2281     int route_sock = -1, ret = KSFT_SKIP;
2282     int test_desc_fd[2];
2283     uint32_t route_seq;
2284     unsigned int i;
2285 
2286     if (argc > 2)
2287         exit_usage(argv);
2288 
2289     if (argc > 1) {
2290         char *endptr;
2291 
2292         errno = 0;
2293         nr_process = strtol(argv[1], &endptr, 10);
2294         if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2295                 || (errno != 0 && nr_process == 0)
2296                 || (endptr == argv[1]) || (*endptr != '\0')) {
2297             printk("Failed to parse [nr_process]");
2298             exit_usage(argv);
2299         }
2300 
2301         if (nr_process > MAX_PROCESSES || !nr_process) {
2302             printk("nr_process should be between [1; %u]",
2303                     MAX_PROCESSES);
2304             exit_usage(argv);
2305         }
2306     }
2307 
2308     srand(time(NULL));
2309     page_size = sysconf(_SC_PAGESIZE);
2310     if (page_size < 1)
2311         ksft_exit_skip("sysconf(): %m\n");
2312 
2313     if (pipe2(test_desc_fd, O_DIRECT) < 0)
2314         ksft_exit_skip("pipe(): %m\n");
2315 
2316     if (pipe2(results_fd, O_DIRECT) < 0)
2317         ksft_exit_skip("pipe(): %m\n");
2318 
2319     if (init_namespaces())
2320         ksft_exit_skip("Failed to create namespaces\n");
2321 
2322     if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2323         ksft_exit_skip("Failed to open netlink route socket\n");
2324 
2325     for (i = 0; i < nr_process; i++) {
2326         char veth[VETH_LEN];
2327 
2328         snprintf(veth, VETH_LEN, VETH_FMT, i);
2329 
2330         if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2331             close(route_sock);
2332             ksft_exit_fail_msg("Failed to create veth device");
2333         }
2334 
2335         if (start_child(i, veth, test_desc_fd)) {
2336             close(route_sock);
2337             ksft_exit_fail_msg("Child %u failed to start", i);
2338         }
2339     }
2340 
2341     if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2342         ksft_exit_fail_msg("close(): %m");
2343 
2344     ksft_set_plan(proto_plan + compat_plan);
2345 
2346     if (write_test_plan(test_desc_fd[1]))
2347         ksft_exit_fail_msg("Failed to write test plan to pipe");
2348 
2349     ret = check_results();
2350 
2351     if (children_cleanup() == KSFT_FAIL)
2352         exit(KSFT_FAIL);
2353 
2354     exit(ret);
2355 }