Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /* Copyright (C) 2009 Red Hat, Inc.
0003  * Author: Michael S. Tsirkin <mst@redhat.com>
0004  *
0005  * virtio-net server in host kernel.
0006  */
0007 
0008 #include <linux/compat.h>
0009 #include <linux/eventfd.h>
0010 #include <linux/vhost.h>
0011 #include <linux/virtio_net.h>
0012 #include <linux/miscdevice.h>
0013 #include <linux/module.h>
0014 #include <linux/moduleparam.h>
0015 #include <linux/mutex.h>
0016 #include <linux/workqueue.h>
0017 #include <linux/file.h>
0018 #include <linux/slab.h>
0019 #include <linux/sched/clock.h>
0020 #include <linux/sched/signal.h>
0021 #include <linux/vmalloc.h>
0022 
0023 #include <linux/net.h>
0024 #include <linux/if_packet.h>
0025 #include <linux/if_arp.h>
0026 #include <linux/if_tun.h>
0027 #include <linux/if_macvlan.h>
0028 #include <linux/if_tap.h>
0029 #include <linux/if_vlan.h>
0030 #include <linux/skb_array.h>
0031 #include <linux/skbuff.h>
0032 
0033 #include <net/sock.h>
0034 #include <net/xdp.h>
0035 
0036 #include "vhost.h"
0037 
0038 static int experimental_zcopytx = 0;
0039 module_param(experimental_zcopytx, int, 0444);
0040 MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
0041                                " 1 -Enable; 0 - Disable");
0042 
0043 /* Max number of bytes transferred before requeueing the job.
0044  * Using this limit prevents one virtqueue from starving others. */
0045 #define VHOST_NET_WEIGHT 0x80000
0046 
0047 /* Max number of packets transferred before requeueing the job.
0048  * Using this limit prevents one virtqueue from starving others with small
0049  * pkts.
0050  */
0051 #define VHOST_NET_PKT_WEIGHT 256
0052 
0053 /* Max number of TX used buffers for outstanding zerocopy; further capped at
 * a quarter of the TX ring size by vhost_exceeds_maxpend().
 */
0054 #define VHOST_MAX_PEND 128
/* TX packets whose payload (after the vhost header) is shorter than this are
 * never sent zerocopy; see handle_tx_zerocopy().
 */
0055 #define VHOST_GOODCOPY_LEN 256
0056 
0057 /*
0058  * For transmit, used buffer len is unused; we override it to track buffer
0059  * status internally; used for zerocopy tx only.
0060  */
0061 /* Lower device DMA failed */
0062 #define VHOST_DMA_FAILED_LEN    ((__force __virtio32)3)
0063 /* Lower device DMA done */
0064 #define VHOST_DMA_DONE_LEN  ((__force __virtio32)2)
0065 /* Lower device DMA in progress */
0066 #define VHOST_DMA_IN_PROGRESS   ((__force __virtio32)1)
0067 /* Buffer unused */
0068 #define VHOST_DMA_CLEAR_LEN ((__force __virtio32)0)
0069 
/* True once the lower device has finished with the buffer, whether the DMA
 * succeeded (DONE) or failed (FAILED): both values compare >= DONE.
 */
0070 #define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)
0071 
/* Feature bits for vhost-net: the common VHOST_FEATURES plus the net-specific
 * bits below. (The ioctl code that consumes this is outside this view.)
 */
0072 enum {
0073     VHOST_NET_FEATURES = VHOST_FEATURES |
0074              (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
0075              (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
0076              (1ULL << VIRTIO_F_ACCESS_PLATFORM)
0077 };
0078 
/* Backend feature bits for vhost-net: IOTLB message format v2. */
0079 enum {
0080     VHOST_NET_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
0081 };
0082 
/* Virtqueue indices; also used to index vhost_net::vqs[] and ::poll[]. */
0083 enum {
0084     VHOST_NET_VQ_RX = 0,
0085     VHOST_NET_VQ_TX = 1,
0086     VHOST_NET_VQ_MAX = 2,
0087 };
0088 
/* Shared reference count for the zerocopy TX buffers of one virtqueue.
 * handle_tx_zerocopy() takes a reference per in-flight zerocopy ubuf, and
 * vhost_net_ubuf_put() drops it. vhost_net_ubuf_put_and_wait() sleeps on
 * @wait until the count reaches zero.
 */
0089 struct vhost_net_ubuf_ref {
0090     /* refcount follows semantics similar to kref:
0091      *  0: object is released
0092      *  1: no outstanding ubufs
0093      * >1: outstanding ubufs
0094      */
0095     atomic_t refcount;
0096     wait_queue_head_t wait;  /* woken when refcount drops to 0 */
0097     struct vhost_virtqueue *vq;  /* queue whose heads[] the zerocopy callback updates */
0098 };
0099 
/* Max number of entries taken from an rx ptr_ring in one batch. The TX copy
 * path also flushes its batch once this many heads are pending.
 */
0100 #define VHOST_NET_BATCH 64
/* Local batch of pointers consumed from an rx ptr_ring: queue[head..tail)
 * holds the entries that have not been consumed yet.
 */
0101 struct vhost_net_buf {
0102     void **queue;
0103     int tail;
0104     int head;
0105 };
0106 
/* Per-virtqueue state of a vhost-net device (one RX and one TX instance). */
0107 struct vhost_net_virtqueue {
0108     struct vhost_virtqueue vq;
0109     size_t vhost_hlen;  /* header bytes skipped at the start of each TX descriptor chain */
0110     size_t sock_hlen;   /* header bytes handed to the backend socket ahead of the packet */
0111     /* vhost zerocopy support fields below: */
0112     /* last used idx for outstanding DMA zerocopy buffers */
0113     int upend_idx;
0114     /* For TX zerocopy, first used idx for DMA done zerocopy buffers.
0115      * For RX and the TX copy path, number of batched heads.
0116      */
0117     int done_idx;
0118     /* Number of XDP frames batched */
0119     int batched_xdp;
0120     /* Array of UIO_MAXIOV userspace buffer infos, indexed like heads[];
 * allocated only for queues with zerocopy enabled.
 */
0121     struct ubuf_info *ubuf_info;
0122     /* Reference counting for outstanding ubufs.
0123      * Protected by vq mutex. Writers must also take device mutex. */
0124     struct vhost_net_ubuf_ref *ubufs;
0125     struct ptr_ring *rx_ring;  /* backend ptr_ring; NULL means read the socket receive queue */
0126     struct vhost_net_buf rxq;  /* local batch drained from rx_ring */
0127     /* Batched XDP buffs: xdp[0..batched_xdp) are pending transmission */
0128     struct xdp_buff *xdp;
0129 };
0130 
/* A vhost-net device: the generic vhost device plus per-queue and TX state. */
0131 struct vhost_net {
0132     struct vhost_dev dev;
0133     struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
0134     struct vhost_poll poll[VHOST_NET_VQ_MAX];  /* backend socket polls, indexed like vqs[] */
0135     /* Number of TX recently submitted.
0136      * Protected by tx vq lock. Reset every 1024 packets. */
0137     unsigned tx_packets;
0138     /* Number of times zerocopy TX recently failed.
0139      * Protected by tx vq lock. Reset together with tx_packets. */
0140     unsigned tx_zcopy_err;
0141     /* Flush in progress. Protected by tx vq lock. */
0142     bool tx_flush;
0143     /* Private page frag used to build batched TX XDP buffers */
0144     struct page_frag page_frag;
0145     /* Refcount bias of page frag: page references taken in advance by
 * vhost_net_page_frag_refill() and not yet handed to a buffer.
 */
0146     int refcnt_bias;
0147 };
0148 
/* Bitmask of virtqueue indices (bit 1 << VHOST_NET_VQ_*) that may use
 * zerocopy; bits are set by vhost_net_enable_zcopy().
 */
0149 static unsigned vhost_net_zcopy_mask __read_mostly;
0150 
0151 static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
0152 {
0153     if (rxq->tail != rxq->head)
0154         return rxq->queue[rxq->head];
0155     else
0156         return NULL;
0157 }
0158 
0159 static int vhost_net_buf_get_size(struct vhost_net_buf *rxq)
0160 {
0161     return rxq->tail - rxq->head;
0162 }
0163 
0164 static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq)
0165 {
0166     return rxq->tail == rxq->head;
0167 }
0168 
0169 static void *vhost_net_buf_consume(struct vhost_net_buf *rxq)
0170 {
0171     void *ret = vhost_net_buf_get_ptr(rxq);
0172     ++rxq->head;
0173     return ret;
0174 }
0175 
0176 static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
0177 {
0178     struct vhost_net_buf *rxq = &nvq->rxq;
0179 
0180     rxq->head = 0;
0181     rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue,
0182                           VHOST_NET_BATCH);
0183     return rxq->tail;
0184 }
0185 
0186 static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq)
0187 {
0188     struct vhost_net_buf *rxq = &nvq->rxq;
0189 
0190     if (nvq->rx_ring && !vhost_net_buf_is_empty(rxq)) {
0191         ptr_ring_unconsume(nvq->rx_ring, rxq->queue + rxq->head,
0192                    vhost_net_buf_get_size(rxq),
0193                    tun_ptr_free);
0194         rxq->head = rxq->tail = 0;
0195     }
0196 }
0197 
0198 static int vhost_net_buf_peek_len(void *ptr)
0199 {
0200     if (tun_is_xdp_frame(ptr)) {
0201         struct xdp_frame *xdpf = tun_ptr_to_xdp(ptr);
0202 
0203         return xdpf->len;
0204     }
0205 
0206     return __skb_array_len_with_tag(ptr);
0207 }
0208 
0209 static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq)
0210 {
0211     struct vhost_net_buf *rxq = &nvq->rxq;
0212 
0213     if (!vhost_net_buf_is_empty(rxq))
0214         goto out;
0215 
0216     if (!vhost_net_buf_produce(nvq))
0217         return 0;
0218 
0219 out:
0220     return vhost_net_buf_peek_len(vhost_net_buf_get_ptr(rxq));
0221 }
0222 
0223 static void vhost_net_buf_init(struct vhost_net_buf *rxq)
0224 {
0225     rxq->head = rxq->tail = 0;
0226 }
0227 
0228 static void vhost_net_enable_zcopy(int vq)
0229 {
0230     vhost_net_zcopy_mask |= 0x1 << vq;
0231 }
0232 
0233 static struct vhost_net_ubuf_ref *
0234 vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
0235 {
0236     struct vhost_net_ubuf_ref *ubufs;
0237     /* No zero copy backend? Nothing to count. */
0238     if (!zcopy)
0239         return NULL;
0240     ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
0241     if (!ubufs)
0242         return ERR_PTR(-ENOMEM);
0243     atomic_set(&ubufs->refcount, 1);
0244     init_waitqueue_head(&ubufs->wait);
0245     ubufs->vq = vq;
0246     return ubufs;
0247 }
0248 
0249 static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
0250 {
0251     int r = atomic_sub_return(1, &ubufs->refcount);
0252     if (unlikely(!r))
0253         wake_up(&ubufs->wait);
0254     return r;
0255 }
0256 
0257 static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
0258 {
0259     vhost_net_ubuf_put(ubufs);
0260     wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
0261 }
0262 
0263 static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
0264 {
0265     vhost_net_ubuf_put_and_wait(ubufs);
0266     kfree(ubufs);
0267 }
0268 
0269 static void vhost_net_clear_ubuf_info(struct vhost_net *n)
0270 {
0271     int i;
0272 
0273     for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
0274         kfree(n->vqs[i].ubuf_info);
0275         n->vqs[i].ubuf_info = NULL;
0276     }
0277 }
0278 
0279 static int vhost_net_set_ubuf_info(struct vhost_net *n)
0280 {
0281     bool zcopy;
0282     int i;
0283 
0284     for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
0285         zcopy = vhost_net_zcopy_mask & (0x1 << i);
0286         if (!zcopy)
0287             continue;
0288         n->vqs[i].ubuf_info =
0289             kmalloc_array(UIO_MAXIOV,
0290                       sizeof(*n->vqs[i].ubuf_info),
0291                       GFP_KERNEL);
0292         if  (!n->vqs[i].ubuf_info)
0293             goto err;
0294     }
0295     return 0;
0296 
0297 err:
0298     vhost_net_clear_ubuf_info(n);
0299     return -ENOMEM;
0300 }
0301 
0302 static void vhost_net_vq_reset(struct vhost_net *n)
0303 {
0304     int i;
0305 
0306     vhost_net_clear_ubuf_info(n);
0307 
0308     for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
0309         n->vqs[i].done_idx = 0;
0310         n->vqs[i].upend_idx = 0;
0311         n->vqs[i].ubufs = NULL;
0312         n->vqs[i].vhost_hlen = 0;
0313         n->vqs[i].sock_hlen = 0;
0314         vhost_net_buf_init(&n->vqs[i].rxq);
0315     }
0316 
0317 }
0318 
0319 static void vhost_net_tx_packet(struct vhost_net *net)
0320 {
0321     ++net->tx_packets;
0322     if (net->tx_packets < 1024)
0323         return;
0324     net->tx_packets = 0;
0325     net->tx_zcopy_err = 0;
0326 }
0327 
0328 static void vhost_net_tx_err(struct vhost_net *net)
0329 {
0330     ++net->tx_zcopy_err;
0331 }
0332 
0333 static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
0334 {
0335     /* TX flush waits for outstanding DMAs to be done.
0336      * Don't start new DMAs.
0337      */
0338     return !net->tx_flush &&
0339         net->tx_packets / 64 >= net->tx_zcopy_err;
0340 }
0341 
0342 static bool vhost_sock_zcopy(struct socket *sock)
0343 {
0344     return unlikely(experimental_zcopytx) &&
0345         sock_flag(sock->sk, SOCK_ZEROCOPY);
0346 }
0347 
0348 static bool vhost_sock_xdp(struct socket *sock)
0349 {
0350     return sock_flag(sock->sk, SOCK_XDP);
0351 }
0352 
0353 /* Zerocopy DMAs may complete out of order in the lower device driver.
0354  * Outstanding zerocopy heads live in heads[done_idx..upend_idx), modulo
0355  * UIO_MAXIOV: upend_idx tracks the end and done_idx the head. This walks
0356  * forward from done_idx over the contiguous run of finished entries
0357  * (DONE or FAILED), clears them, and adds them to the used ring, signalling
 * the guest. Each FAILED entry is also counted with vhost_net_tx_err().
 */
0358 static void vhost_zerocopy_signal_used(struct vhost_net *net,
0359                        struct vhost_virtqueue *vq)
0360 {
0361     struct vhost_net_virtqueue *nvq =
0362         container_of(vq, struct vhost_net_virtqueue, vq);
0363     int i, add;
0364     int j = 0;
0365 
         /* Count and clear finished entries; stop at the first one still in flight. */
0366     for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
0367         if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
0368             vhost_net_tx_err(net);
0369         if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
0370             vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
0371             ++j;
0372         } else
0373             break;
0374     }
         /* heads[] is circular: publish in a second chunk if the run wraps past the end. */
0375     while (j) {
0376         add = min(UIO_MAXIOV - nvq->done_idx, j);
0377         vhost_add_used_and_signal_n(vq->dev, vq,
0378                         &vq->heads[nvq->done_idx], add);
0379         nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
0380         j -= add;
0381     }
0382 }
0383 
/* Completion callback installed on each zerocopy ubuf by handle_tx_zerocopy().
 * It does three things:
 *  - marks the descriptor's heads[] entry DMA done or failed according to
 *    @success;
 *  - drops the in-flight reference on the queue's ubufs counter;
 *  - queues the TX poll work when the remaining count is <= 1 or a multiple
 *    of 16.
 * NOTE(review): the rcu_read_lock_bh() section presumably keeps vq/ubufs valid
 * against concurrent teardown — confirm against the release/flush path.
 */
0384 static void vhost_zerocopy_callback(struct sk_buff *skb,
0385                     struct ubuf_info *ubuf, bool success)
0386 {
0387     struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
0388     struct vhost_virtqueue *vq = ubufs->vq;
0389     int cnt;
0390 
0391     rcu_read_lock_bh();
0392 
0393     /* set len to mark this desc buffers done DMA */
0394     vq->heads[ubuf->desc].len = success ?
0395         VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
0396     cnt = vhost_net_ubuf_put(ubufs);
0397 
0398     /*
0399      * Trigger polling thread if guest stopped submitting new buffers:
0400      * in this case, the refcount after decrement will eventually reach 1.
0401      * We also trigger polling periodically after each 16 packets
0402      * (the value 16 here is more or less arbitrary, it's tuned to trigger
0403      * less than 10% of times).
0404      */
0405     if (cnt <= 1 || !(cnt % 16))
0406         vhost_poll_queue(&vq->poll);
0407 
0408     rcu_read_unlock_bh();
0409 }
0410 
0411 static inline unsigned long busy_clock(void)
0412 {
0413     return local_clock() >> 10;
0414 }
0415 
0416 static bool vhost_can_busy_poll(unsigned long endtime)
0417 {
0418     return likely(!need_resched() && !time_after(busy_clock(), endtime) &&
0419               !signal_pending(current));
0420 }
0421 
0422 static void vhost_net_disable_vq(struct vhost_net *n,
0423                  struct vhost_virtqueue *vq)
0424 {
0425     struct vhost_net_virtqueue *nvq =
0426         container_of(vq, struct vhost_net_virtqueue, vq);
0427     struct vhost_poll *poll = n->poll + (nvq - n->vqs);
0428     if (!vhost_vq_get_backend(vq))
0429         return;
0430     vhost_poll_stop(poll);
0431 }
0432 
0433 static int vhost_net_enable_vq(struct vhost_net *n,
0434                 struct vhost_virtqueue *vq)
0435 {
0436     struct vhost_net_virtqueue *nvq =
0437         container_of(vq, struct vhost_net_virtqueue, vq);
0438     struct vhost_poll *poll = n->poll + (nvq - n->vqs);
0439     struct socket *sock;
0440 
0441     sock = vhost_vq_get_backend(vq);
0442     if (!sock)
0443         return 0;
0444 
0445     return vhost_poll_start(poll, sock->file);
0446 }
0447 
0448 static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq)
0449 {
0450     struct vhost_virtqueue *vq = &nvq->vq;
0451     struct vhost_dev *dev = vq->dev;
0452 
0453     if (!nvq->done_idx)
0454         return;
0455 
0456     vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
0457     nvq->done_idx = 0;
0458 }
0459 
/* Flush the TX batch accumulated by handle_tx_copy().
 * If XDP buffs are batched, they are passed to the backend in one sendmsg()
 * carrying a TUN_MSG_PTR control message. Then the batched heads
 * (heads[0..done_idx)) are added to the used ring and the guest is signalled.
 * On sendmsg() failure, vq_err() is raised, the batched pages are released,
 * and the batched heads are discarded without being added to the used ring.
 * NOTE: msghdr->msg_control is left pointing at the on-stack ctl after this
 * returns; callers must reset it before reusing @msghdr for a plain sendmsg().
 */
0460 static void vhost_tx_batch(struct vhost_net *net,
0461                struct vhost_net_virtqueue *nvq,
0462                struct socket *sock,
0463                struct msghdr *msghdr)
0464 {
0465     struct tun_msg_ctl ctl = {
0466         .type = TUN_MSG_PTR,
0467         .num = nvq->batched_xdp,
0468         .ptr = nvq->xdp,
0469     };
0470     int i, err;
0471 
0472     if (nvq->batched_xdp == 0)
0473         goto signal_used;
0474 
0475     msghdr->msg_control = &ctl;
0476     msghdr->msg_controllen = sizeof(ctl);
0477     err = sock->ops->sendmsg(sock, msghdr, 0);
0478     if (unlikely(err < 0)) {
0479         vq_err(&nvq->vq, "Fail to batch sending packets\n");
0480 
0481         /* free pages owned by XDP; since this is an unlikely error path,
0482          * keep it simple and avoid more complex bulk update for the
0483          * used pages
0484          */
0485         for (i = 0; i < nvq->batched_xdp; ++i)
0486             put_page(virt_to_head_page(nvq->xdp[i].data));
0487         nvq->batched_xdp = 0;
0488         nvq->done_idx = 0;
0489         return;
0490     }
0491 
0492 signal_used:
0493     vhost_net_signal_used(nvq);
0494     nvq->batched_xdp = 0;
0495 }
0496 
0497 static int sock_has_rx_data(struct socket *sock)
0498 {
0499     if (unlikely(!sock))
0500         return 0;
0501 
0502     if (sock->ops->peek_len)
0503         return sock->ops->peek_len(sock);
0504 
0505     return skb_queue_empty(&sock->sk->sk_receive_queue);
0506 }
0507 
0508 static void vhost_net_busy_poll_try_queue(struct vhost_net *net,
0509                       struct vhost_virtqueue *vq)
0510 {
0511     if (!vhost_vq_avail_empty(&net->dev, vq)) {
0512         vhost_poll_queue(&vq->poll);
0513     } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
0514         vhost_disable_notify(&net->dev, vq);
0515         vhost_poll_queue(&vq->poll);
0516     }
0517 }
0518 
0519 static void vhost_net_busy_poll(struct vhost_net *net,
0520                 struct vhost_virtqueue *rvq,
0521                 struct vhost_virtqueue *tvq,
0522                 bool *busyloop_intr,
0523                 bool poll_rx)
0524 {
0525     unsigned long busyloop_timeout;
0526     unsigned long endtime;
0527     struct socket *sock;
0528     struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;
0529 
0530     /* Try to hold the vq mutex of the paired virtqueue. We can't
0531      * use mutex_lock() here since we could not guarantee a
0532      * consistenet lock ordering.
0533      */
0534     if (!mutex_trylock(&vq->mutex))
0535         return;
0536 
0537     vhost_disable_notify(&net->dev, vq);
0538     sock = vhost_vq_get_backend(rvq);
0539 
0540     busyloop_timeout = poll_rx ? rvq->busyloop_timeout:
0541                      tvq->busyloop_timeout;
0542 
0543     preempt_disable();
0544     endtime = busy_clock() + busyloop_timeout;
0545 
0546     while (vhost_can_busy_poll(endtime)) {
0547         if (vhost_has_work(&net->dev)) {
0548             *busyloop_intr = true;
0549             break;
0550         }
0551 
0552         if ((sock_has_rx_data(sock) &&
0553              !vhost_vq_avail_empty(&net->dev, rvq)) ||
0554             !vhost_vq_avail_empty(&net->dev, tvq))
0555             break;
0556 
0557         cpu_relax();
0558     }
0559 
0560     preempt_enable();
0561 
0562     if (poll_rx || sock_has_rx_data(sock))
0563         vhost_net_busy_poll_try_queue(net, vq);
0564     else if (!poll_rx) /* On tx here, sock has no rx data. */
0565         vhost_enable_notify(&net->dev, rvq);
0566 
0567     mutex_unlock(&vq->mutex);
0568 }
0569 
0570 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
0571                     struct vhost_net_virtqueue *tnvq,
0572                     unsigned int *out_num, unsigned int *in_num,
0573                     struct msghdr *msghdr, bool *busyloop_intr)
0574 {
0575     struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
0576     struct vhost_virtqueue *rvq = &rnvq->vq;
0577     struct vhost_virtqueue *tvq = &tnvq->vq;
0578 
0579     int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
0580                   out_num, in_num, NULL, NULL);
0581 
0582     if (r == tvq->num && tvq->busyloop_timeout) {
0583         /* Flush batched packets first */
0584         if (!vhost_sock_zcopy(vhost_vq_get_backend(tvq)))
0585             vhost_tx_batch(net, tnvq,
0586                        vhost_vq_get_backend(tvq),
0587                        msghdr);
0588 
0589         vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false);
0590 
0591         r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
0592                       out_num, in_num, NULL, NULL);
0593     }
0594 
0595     return r;
0596 }
0597 
0598 static bool vhost_exceeds_maxpend(struct vhost_net *net)
0599 {
0600     struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
0601     struct vhost_virtqueue *vq = &nvq->vq;
0602 
0603     return (nvq->upend_idx + UIO_MAXIOV - nvq->done_idx) % UIO_MAXIOV >
0604            min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2);
0605 }
0606 
0607 static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter,
0608                 size_t hdr_size, int out)
0609 {
0610     /* Skip header. TODO: support TSO. */
0611     size_t len = iov_length(vq->iov, out);
0612 
0613     iov_iter_init(iter, WRITE, vq->iov, out, len);
0614     iov_iter_advance(iter, hdr_size);
0615 
0616     return iov_iter_count(iter);
0617 }
0618 
0619 static int get_tx_bufs(struct vhost_net *net,
0620                struct vhost_net_virtqueue *nvq,
0621                struct msghdr *msg,
0622                unsigned int *out, unsigned int *in,
0623                size_t *len, bool *busyloop_intr)
0624 {
0625     struct vhost_virtqueue *vq = &nvq->vq;
0626     int ret;
0627 
0628     ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);
0629 
0630     if (ret < 0 || ret == vq->num)
0631         return ret;
0632 
0633     if (*in) {
0634         vq_err(vq, "Unexpected descriptor format for TX: out %d, int %d\n",
0635             *out, *in);
0636         return -EFAULT;
0637     }
0638 
0639     /* Sanity check */
0640     *len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out);
0641     if (*len == 0) {
0642         vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n",
0643             *len, nvq->vhost_hlen);
0644         return -EFAULT;
0645     }
0646 
0647     return ret;
0648 }
0649 
0650 static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
0651 {
0652     return total_len < VHOST_NET_WEIGHT &&
0653            !vhost_vq_avail_empty(vq->dev, vq);
0654 }
0655 
/* Ensure @pfrag has at least @sz bytes free starting at pfrag->offset.
 *
 * If the current page still fits, it is kept. Otherwise the old page's unused
 * bias references are dropped, and a new page is allocated: first a
 * SKB_FRAG_PAGE_ORDER compound page without direct reclaim, then a single
 * page with @gfp. A new page is pre-charged with USHRT_MAX - 1 extra
 * references, and net->refcnt_bias is set to USHRT_MAX. Each buffer carved
 * out of the page then takes a reference by decrementing the bias (see
 * vhost_net_build_xdp()).
 *
 * Returns false (with pfrag->page NULL) if no page could be allocated.
 */
0656 static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz,
0657                        struct page_frag *pfrag, gfp_t gfp)
0658 {
0659     if (pfrag->page) {
0660         if (pfrag->offset + sz <= pfrag->size)
0661             return true;
0662         __page_frag_cache_drain(pfrag->page, net->refcnt_bias);  /* release the old page's remaining references */
0663     }
0664 
0665     pfrag->offset = 0;
0666     net->refcnt_bias = 0;
0667     if (SKB_FRAG_PAGE_ORDER) {
0668         /* Avoid direct reclaim but allow kswapd to wake */
0669         pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) |
0670                       __GFP_COMP | __GFP_NOWARN |
0671                       __GFP_NORETRY,
0672                       SKB_FRAG_PAGE_ORDER);
0673         if (likely(pfrag->page)) {
0674             pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER;
0675             goto done;
0676         }
0677     }
0678     pfrag->page = alloc_page(gfp);  /* fall back to a single order-0 page */
0679     if (likely(pfrag->page)) {
0680         pfrag->size = PAGE_SIZE;
0681         goto done;
0682     }
0683     return false;
0684 
0685 done:
0686     net->refcnt_bias = USHRT_MAX;
0687     page_ref_add(pfrag->page, USHRT_MAX - 1);  /* page now holds USHRT_MAX refs, matching the bias */
0688     return true;
0689 }
0690 
/* Base padding included, together with XDP headroom and the socket header,
 * in the space reserved before packet data by vhost_net_build_xdp().
 */
0691 #define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
0692 
/* Copy one TX packet from @from into a new buffer carved out of
 * net->page_frag, and append it to nvq->xdp[] (batched_xdp is incremented).
 *
 * Buffer layout:
 *  - a struct tun_xdp_hdr at the start; its gso field receives the first
 *    sock_hlen bytes read from @from;
 *  - the remaining packet data at offset pad;
 *  - space for struct skb_shared_info at the end.
 * The buffer is only committed (frag offset advanced past it, refcnt_bias
 * decremented) on success.
 *
 * Returns:
 *  - 0 on success;
 *  - -ENOSPC if the packet does not fit in one page (the caller then sends
 *    it through the regular sendmsg() path);
 *  - -ENOMEM if no page could be allocated;
 *  - -EFAULT on a short packet or short copy;
 *  - -EINVAL if the checksum fields push gso hdr_len beyond the packet
 *    length.
 */
0693 static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
0694                    struct iov_iter *from)
0695 {
0696     struct vhost_virtqueue *vq = &nvq->vq;
0697     struct vhost_net *net = container_of(vq->dev, struct vhost_net,
0698                          dev);
0699     struct socket *sock = vhost_vq_get_backend(vq);
0700     struct page_frag *alloc_frag = &net->page_frag;
0701     struct virtio_net_hdr *gso;
0702     struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
0703     struct tun_xdp_hdr *hdr;
0704     size_t len = iov_iter_count(from);
0705     int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0;
0706     int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
0707     int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen);
0708     int sock_hlen = nvq->sock_hlen;
0709     void *buf;
0710     int copied;
0711 
0712     if (unlikely(len < nvq->sock_hlen))
0713         return -EFAULT;
0714 
         /* Batched buffers must fit in a single page. */
0715     if (SKB_DATA_ALIGN(len + pad) +
0716         SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE)
0717         return -ENOSPC;
0718 
0719     buflen += SKB_DATA_ALIGN(len + pad);
0720     alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
0721     if (unlikely(!vhost_net_page_frag_refill(net, buflen,
0722                          alloc_frag, GFP_KERNEL)))
0723         return -ENOMEM;
0724 
0725     buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
0726     copied = copy_page_from_iter(alloc_frag->page,
0727                      alloc_frag->offset +
0728                      offsetof(struct tun_xdp_hdr, gso),
0729                      sock_hlen, from);
0730     if (copied != sock_hlen)
0731         return -EFAULT;
0732 
0733     hdr = buf;
0734     gso = &hdr->gso;
0735 
         /* With checksum offload requested, make hdr_len cover the checksum
          * field (csum_start + csum_offset + 2 bytes).
          */
0736     if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
0737         vhost16_to_cpu(vq, gso->csum_start) +
0738         vhost16_to_cpu(vq, gso->csum_offset) + 2 >
0739         vhost16_to_cpu(vq, gso->hdr_len)) {
0740         gso->hdr_len = cpu_to_vhost16(vq,
0741                    vhost16_to_cpu(vq, gso->csum_start) +
0742                    vhost16_to_cpu(vq, gso->csum_offset) + 2);
0743 
0744         if (vhost16_to_cpu(vq, gso->hdr_len) > len)
0745             return -EINVAL;
0746     }
0747 
0748     len -= sock_hlen;
0749     copied = copy_page_from_iter(alloc_frag->page,
0750                      alloc_frag->offset + pad,
0751                      len, from);
0752     if (copied != len)
0753         return -EFAULT;
0754 
0755     xdp_init_buff(xdp, buflen, NULL);
0756     xdp_prepare_buff(xdp, buf, pad, len, true);
0757     hdr->buflen = buflen;
0758 
         /* Commit: hand one pre-taken page reference to this buffer. */
0759     --net->refcnt_bias;
0760     alloc_frag->offset += buflen;
0761 
0762     ++nvq->batched_xdp;
0763 
0764     return 0;
0765 }
0766 
/* TX handler for backends without zerocopy. It processes descriptors until
 * the ring is empty, an error occurs, or vhost_exceeds_weight() asks it to
 * yield.
 *
 * Batched mode (socket sk_sndbuf is INT_MAX, i.e. unlimited): packets are
 * built into XDP buffs with vhost_net_build_xdp() and sent in batches by
 * vhost_tx_batch(). A packet too large for that path (-ENOSPC) first flushes
 * the batch and is then sent with a plain sendmsg().
 *
 * Otherwise, each packet is sent with sendmsg(), using MSG_MORE while
 * tx_can_batch() allows.
 *
 * Completed heads are collected in heads[0..done_idx) and published when the
 * batch is flushed:
 *  - after VHOST_NET_BATCH heads;
 *  - when the ring looks empty during busy polling;
 *  - on exit.
 */
0767 static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
0768 {
0769     struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
0770     struct vhost_virtqueue *vq = &nvq->vq;
0771     unsigned out, in;
0772     int head;
0773     struct msghdr msg = {
0774         .msg_name = NULL,
0775         .msg_namelen = 0,
0776         .msg_control = NULL,
0777         .msg_controllen = 0,
0778         .msg_flags = MSG_DONTWAIT,
0779     };
0780     size_t len, total_len = 0;
0781     int err;
0782     int sent_pkts = 0;
0783     bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
0784 
0785     do {
0786         bool busyloop_intr = false;
0787 
0788         if (nvq->done_idx == VHOST_NET_BATCH)
0789             vhost_tx_batch(net, nvq, sock, &msg);
0790 
0791         head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
0792                    &busyloop_intr);
0793         /* On error, stop handling until the next kick. */
0794         if (unlikely(head < 0))
0795             break;
0796         /* Nothing new?  Wait for eventfd to tell us they refilled. */
0797         if (head == vq->num) {
0798             if (unlikely(busyloop_intr)) {
0799                 vhost_poll_queue(&vq->poll);
0800             } else if (unlikely(vhost_enable_notify(&net->dev,
0801                                 vq))) {
0802                 vhost_disable_notify(&net->dev, vq);
0803                 continue;
0804             }
0805             break;
0806         }
0807 
0808         total_len += len;
0809 
0810         /* For simplicity, TX batching is only enabled if
0811          * sndbuf is unlimited.
0812          */
0813         if (sock_can_batch) {
0814             err = vhost_net_build_xdp(nvq, &msg.msg_iter);
0815             if (!err) {
0816                 goto done;
0817             } else if (unlikely(err != -ENOSPC)) {
                     /* Hard failure: flush the batch, give this descriptor
                      * back, and restart polling the backend socket.
                      */
0818                 vhost_tx_batch(net, nvq, sock, &msg);
0819                 vhost_discard_vq_desc(vq, 1);
0820                 vhost_net_enable_vq(net, vq);
0821                 break;
0822             }
0823 
0824             /* We can't build XDP buff, go for single
0825              * packet path but let's flush batched
0826              * packets.
0827              */
0828             vhost_tx_batch(net, nvq, sock, &msg);
0829             msg.msg_control = NULL;  /* vhost_tx_batch() may have pointed it at its on-stack ctl */
0830         } else {
0831             if (tx_can_batch(vq, total_len))
0832                 msg.msg_flags |= MSG_MORE;
0833             else
0834                 msg.msg_flags &= ~MSG_MORE;
0835         }
0836 
0837         err = sock->ops->sendmsg(sock, &msg, len);
0838         if (unlikely(err < 0)) {
                 /* Transient backend errors: give the descriptor back and
                  * restart polling the backend socket.
                  */
0839             if (err == -EAGAIN || err == -ENOMEM || err == -ENOBUFS) {
0840                 vhost_discard_vq_desc(vq, 1);
0841                 vhost_net_enable_vq(net, vq);
0842                 break;
0843             }
0844             pr_debug("Fail to send packet: err %d", err);  /* packet dropped; head still completed below */
0845         } else if (unlikely(err != len))
0846             pr_debug("Truncated TX packet: len %d != %zd\n",
0847                  err, len);
0848 done:
0849         vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
0850         vq->heads[nvq->done_idx].len = 0;
0851         ++nvq->done_idx;
0852     } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
0853 
0854     vhost_tx_batch(net, nvq, sock, &msg);
0855 }
0856 
/* TX handler for zerocopy-capable backends (see vhost_sock_zcopy()).
 *
 * A packet is sent zerocopy when all of these hold:
 *  - its payload is at least VHOST_GOODCOPY_LEN bytes;
 *  - the outstanding zerocopy heads are within vhost_exceeds_maxpend();
 *  - vhost_net_tx_select_zcopy() allows it.
 * Its head is then recorded at heads[upend_idx] as VHOST_DMA_IN_PROGRESS,
 * with a ubuf whose callback is vhost_zerocopy_callback(), and is published
 * later by vhost_zerocopy_signal_used(). Other packets are completed
 * immediately with vhost_add_used_and_signal().
 */
0857 static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
0858 {
0859     struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
0860     struct vhost_virtqueue *vq = &nvq->vq;
0861     unsigned out, in;
0862     int head;
0863     struct msghdr msg = {
0864         .msg_name = NULL,
0865         .msg_namelen = 0,
0866         .msg_control = NULL,
0867         .msg_controllen = 0,
0868         .msg_flags = MSG_DONTWAIT,
0869     };
0870     struct tun_msg_ctl ctl;
0871     size_t len, total_len = 0;
0872     int err;
0873     struct vhost_net_ubuf_ref *ubufs;
0874     struct ubuf_info *ubuf;
0875     bool zcopy_used;
0876     int sent_pkts = 0;
0877 
0878     do {
0879         bool busyloop_intr;
0880 
0881         /* Release DMAs done buffers first */
0882         vhost_zerocopy_signal_used(net, vq);
0883 
0884         busyloop_intr = false;
0885         head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
0886                    &busyloop_intr);
0887         /* On error, stop handling until the next kick. */
0888         if (unlikely(head < 0))
0889             break;
0890         /* Nothing new?  Wait for eventfd to tell us they refilled. */
0891         if (head == vq->num) {
0892             if (unlikely(busyloop_intr)) {
0893                 vhost_poll_queue(&vq->poll);
0894             } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
0895                 vhost_disable_notify(&net->dev, vq);
0896                 continue;
0897             }
0898             break;
0899         }
0900 
0901         zcopy_used = len >= VHOST_GOODCOPY_LEN
0902                  && !vhost_exceeds_maxpend(net)
0903                  && vhost_net_tx_select_zcopy(net);
0904 
0905         /* use msg_control to pass vhost zerocopy ubuf info to skb */
0906         if (zcopy_used) {
0907             ubuf = nvq->ubuf_info + nvq->upend_idx;
0908             vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
0909             vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
0910             ubuf->callback = vhost_zerocopy_callback;
0911             ubuf->ctx = nvq->ubufs;
0912             ubuf->desc = nvq->upend_idx;
0913             ubuf->flags = SKBFL_ZEROCOPY_FRAG;
0914             refcount_set(&ubuf->refcnt, 1);
0915             msg.msg_control = &ctl;
0916             ctl.type = TUN_MSG_UBUF;
0917             ctl.ptr = ubuf;
0918             msg.msg_controllen = sizeof(ctl);
0919             ubufs = nvq->ubufs;
0920             atomic_inc(&ubufs->refcount);  /* in-flight reference, dropped by the callback */
0921             nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
0922         } else {
0923             msg.msg_control = NULL;
0924             ubufs = NULL;
0925         }
0926         total_len += len;
0927         if (tx_can_batch(vq, total_len) &&
0928             likely(!vhost_exceeds_maxpend(net))) {
0929             msg.msg_flags |= MSG_MORE;
0930         } else {
0931             msg.msg_flags &= ~MSG_MORE;
0932         }
0933 
0934         err = sock->ops->sendmsg(sock, &msg, len);
0935         if (unlikely(err < 0)) {
                 /* Undo the zerocopy bookkeeping: drop the reference taken
                  * above unless the completion callback already ran for this
                  * entry, and rewind upend_idx.
                  * NOTE(review): for errors other than -EAGAIN/-ENOMEM/-ENOBUFS
                  * the rewound head is neither discarded nor added to the used
                  * ring (vhost_zerocopy_signal_used() stops at upend_idx), so the
                  * descriptor appears never to be returned to the guest — confirm.
                  */
0936             if (zcopy_used) {
0937                 if (vq->heads[ubuf->desc].len == VHOST_DMA_IN_PROGRESS)
0938                     vhost_net_ubuf_put(ubufs);
0939                 nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
0940                     % UIO_MAXIOV;
0941             }
0942             if (err == -EAGAIN || err == -ENOMEM || err == -ENOBUFS) {
0943                 vhost_discard_vq_desc(vq, 1);
0944                 vhost_net_enable_vq(net, vq);
0945                 break;
0946             }
0947             pr_debug("Fail to send packet: err %d", err);
0948         } else if (unlikely(err != len))
0949             pr_debug("Truncated TX packet: "
0950                  " len %d != %zd\n", err, len);
0951         if (!zcopy_used)
0952             vhost_add_used_and_signal(&net->dev, vq, head, 0);
0953         else
0954             vhost_zerocopy_signal_used(net, vq);
0955         vhost_net_tx_packet(net);
0956     } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
0957 }
0958 
0959 /* Expects to be always run from workqueue - which acts as
0960  * read-size critical section for our kind of RCU. */
0961 static void handle_tx(struct vhost_net *net)
0962 {
0963     struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
0964     struct vhost_virtqueue *vq = &nvq->vq;
0965     struct socket *sock;
0966 
0967     mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_TX);
0968     sock = vhost_vq_get_backend(vq);
0969     if (!sock)
0970         goto out;
0971 
0972     if (!vq_meta_prefetch(vq))
0973         goto out;
0974 
0975     vhost_disable_notify(&net->dev, vq);
0976     vhost_net_disable_vq(net, vq);
0977 
0978     if (vhost_sock_zcopy(sock))
0979         handle_tx_zerocopy(net, sock);
0980     else
0981         handle_tx_copy(net, sock);
0982 
0983 out:
0984     mutex_unlock(&vq->mutex);
0985 }
0986 
0987 static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk)
0988 {
0989     struct sk_buff *head;
0990     int len = 0;
0991     unsigned long flags;
0992 
0993     if (rvq->rx_ring)
0994         return vhost_net_buf_peek(rvq);
0995 
0996     spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
0997     head = skb_peek(&sk->sk_receive_queue);
0998     if (likely(head)) {
0999         len = head->len;
1000         if (skb_vlan_tag_present(head))
1001             len += VLAN_HLEN;
1002     }
1003 
1004     spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
1005     return len;
1006 }
1007 
1008 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
1009                       bool *busyloop_intr)
1010 {
1011     struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
1012     struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
1013     struct vhost_virtqueue *rvq = &rnvq->vq;
1014     struct vhost_virtqueue *tvq = &tnvq->vq;
1015     int len = peek_head_len(rnvq, sk);
1016 
1017     if (!len && rvq->busyloop_timeout) {
1018         /* Flush batched heads first */
1019         vhost_net_signal_used(rnvq);
1020         /* Both tx vq and rx socket were polled here */
1021         vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true);
1022 
1023         len = peek_head_len(rnvq, sk);
1024     }
1025 
1026     return len;
1027 }
1028 
1029 /* This is a multi-buffer version of vhost_get_desc, that works if
1030  *  vq has read descriptors only.
1031  * @vq      - the relevant virtqueue
1032  * @datalen - data length we'll be reading
1033  * @iovcount    - returned count of io vectors we fill
1034  * @log     - vhost log
1035  * @log_num - log offset
1036  * @quota       - headcount quota, 1 for big buffer
1037  *  returns number of buffer heads allocated, negative on error
1038  */
1039 static int get_rx_bufs(struct vhost_virtqueue *vq,
1040                struct vring_used_elem *heads,
1041                int datalen,
1042                unsigned *iovcount,
1043                struct vhost_log *log,
1044                unsigned *log_num,
1045                unsigned int quota)
1046 {
1047     unsigned int out, in;
1048     int seg = 0;
1049     int headcount = 0;
1050     unsigned d;
1051     int r, nlogs = 0;
1052     /* len is always initialized before use since we are always called with
1053      * datalen > 0.
1054      */
1055     u32 len;
1056 
1057     while (datalen > 0 && headcount < quota) {
1058         if (unlikely(seg >= UIO_MAXIOV)) {
1059             r = -ENOBUFS;
1060             goto err;
1061         }
1062         r = vhost_get_vq_desc(vq, vq->iov + seg,
1063                       ARRAY_SIZE(vq->iov) - seg, &out,
1064                       &in, log, log_num);
1065         if (unlikely(r < 0))
1066             goto err;
1067 
1068         d = r;
1069         if (d == vq->num) {
1070             r = 0;
1071             goto err;
1072         }
1073         if (unlikely(out || in <= 0)) {
1074             vq_err(vq, "unexpected descriptor format for RX: "
1075                 "out %d, in %d\n", out, in);
1076             r = -EINVAL;
1077             goto err;
1078         }
1079         if (unlikely(log)) {
1080             nlogs += *log_num;
1081             log += *log_num;
1082         }
1083         heads[headcount].id = cpu_to_vhost32(vq, d);
1084         len = iov_length(vq->iov + seg, in);
1085         heads[headcount].len = cpu_to_vhost32(vq, len);
1086         datalen -= len;
1087         ++headcount;
1088         seg += in;
1089     }
1090     heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
1091     *iovcount = seg;
1092     if (unlikely(log))
1093         *log_num = nlogs;
1094 
1095     /* Detect overrun */
1096     if (unlikely(datalen > 0)) {
1097         r = UIO_MAXIOV + 1;
1098         goto err;
1099     }
1100     return headcount;
1101 err:
1102     vhost_discard_vq_desc(vq, headcount);
1103     return r;
1104 }
1105 
1106 /* Expects to be always run from workqueue - which acts as
1107  * read-size critical section for our kind of RCU. */
1108 static void handle_rx(struct vhost_net *net)
1109 {
1110     struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
1111     struct vhost_virtqueue *vq = &nvq->vq;
1112     unsigned in, log;
1113     struct vhost_log *vq_log;
1114     struct msghdr msg = {
1115         .msg_name = NULL,
1116         .msg_namelen = 0,
1117         .msg_control = NULL, /* FIXME: get and handle RX aux data. */
1118         .msg_controllen = 0,
1119         .msg_flags = MSG_DONTWAIT,
1120     };
1121     struct virtio_net_hdr hdr = {
1122         .flags = 0,
1123         .gso_type = VIRTIO_NET_HDR_GSO_NONE
1124     };
1125     size_t total_len = 0;
1126     int err, mergeable;
1127     s16 headcount;
1128     size_t vhost_hlen, sock_hlen;
1129     size_t vhost_len, sock_len;
1130     bool busyloop_intr = false;
1131     struct socket *sock;
1132     struct iov_iter fixup;
1133     __virtio16 num_buffers;
1134     int recv_pkts = 0;
1135 
1136     mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX);
1137     sock = vhost_vq_get_backend(vq);
1138     if (!sock)
1139         goto out;
1140 
1141     if (!vq_meta_prefetch(vq))
1142         goto out;
1143 
1144     vhost_disable_notify(&net->dev, vq);
1145     vhost_net_disable_vq(net, vq);
1146 
1147     vhost_hlen = nvq->vhost_hlen;
1148     sock_hlen = nvq->sock_hlen;
1149 
1150     vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
1151         vq->log : NULL;
1152     mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
1153 
1154     do {
1155         sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
1156                               &busyloop_intr);
1157         if (!sock_len)
1158             break;
1159         sock_len += sock_hlen;
1160         vhost_len = sock_len + vhost_hlen;
1161         headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
1162                     vhost_len, &in, vq_log, &log,
1163                     likely(mergeable) ? UIO_MAXIOV : 1);
1164         /* On error, stop handling until the next kick. */
1165         if (unlikely(headcount < 0))
1166             goto out;
1167         /* OK, now we need to know about added descriptors. */
1168         if (!headcount) {
1169             if (unlikely(busyloop_intr)) {
1170                 vhost_poll_queue(&vq->poll);
1171             } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
1172                 /* They have slipped one in as we were
1173                  * doing that: check again. */
1174                 vhost_disable_notify(&net->dev, vq);
1175                 continue;
1176             }
1177             /* Nothing new?  Wait for eventfd to tell us
1178              * they refilled. */
1179             goto out;
1180         }
1181         busyloop_intr = false;
1182         if (nvq->rx_ring)
1183             msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
1184         /* On overrun, truncate and discard */
1185         if (unlikely(headcount > UIO_MAXIOV)) {
1186             iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
1187             err = sock->ops->recvmsg(sock, &msg,
1188                          1, MSG_DONTWAIT | MSG_TRUNC);
1189             pr_debug("Discarded rx packet: len %zd\n", sock_len);
1190             continue;
1191         }
1192         /* We don't need to be notified again. */
1193         iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
1194         fixup = msg.msg_iter;
1195         if (unlikely((vhost_hlen))) {
1196             /* We will supply the header ourselves
1197              * TODO: support TSO.
1198              */
1199             iov_iter_advance(&msg.msg_iter, vhost_hlen);
1200         }
1201         err = sock->ops->recvmsg(sock, &msg,
1202                      sock_len, MSG_DONTWAIT | MSG_TRUNC);
1203         /* Userspace might have consumed the packet meanwhile:
1204          * it's not supposed to do this usually, but might be hard
1205          * to prevent. Discard data we got (if any) and keep going. */
1206         if (unlikely(err != sock_len)) {
1207             pr_debug("Discarded rx packet: "
1208                  " len %d, expected %zd\n", err, sock_len);
1209             vhost_discard_vq_desc(vq, headcount);
1210             continue;
1211         }
1212         /* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
1213         if (unlikely(vhost_hlen)) {
1214             if (copy_to_iter(&hdr, sizeof(hdr),
1215                      &fixup) != sizeof(hdr)) {
1216                 vq_err(vq, "Unable to write vnet_hdr "
1217                        "at addr %p\n", vq->iov->iov_base);
1218                 goto out;
1219             }
1220         } else {
1221             /* Header came from socket; we'll need to patch
1222              * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
1223              */
1224             iov_iter_advance(&fixup, sizeof(hdr));
1225         }
1226         /* TODO: Should check and handle checksum. */
1227 
1228         num_buffers = cpu_to_vhost16(vq, headcount);
1229         if (likely(mergeable) &&
1230             copy_to_iter(&num_buffers, sizeof num_buffers,
1231                  &fixup) != sizeof num_buffers) {
1232             vq_err(vq, "Failed num_buffers write");
1233             vhost_discard_vq_desc(vq, headcount);
1234             goto out;
1235         }
1236         nvq->done_idx += headcount;
1237         if (nvq->done_idx > VHOST_NET_BATCH)
1238             vhost_net_signal_used(nvq);
1239         if (unlikely(vq_log))
1240             vhost_log_write(vq, vq_log, log, vhost_len,
1241                     vq->iov, in);
1242         total_len += vhost_len;
1243     } while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
1244 
1245     if (unlikely(busyloop_intr))
1246         vhost_poll_queue(&vq->poll);
1247     else if (!sock_len)
1248         vhost_net_enable_vq(net, vq);
1249 out:
1250     vhost_net_signal_used(nvq);
1251     mutex_unlock(&vq->mutex);
1252 }
1253 
1254 static void handle_tx_kick(struct vhost_work *work)
1255 {
1256     struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
1257                           poll.work);
1258     struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
1259 
1260     handle_tx(net);
1261 }
1262 
1263 static void handle_rx_kick(struct vhost_work *work)
1264 {
1265     struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
1266                           poll.work);
1267     struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
1268 
1269     handle_rx(net);
1270 }
1271 
1272 static void handle_tx_net(struct vhost_work *work)
1273 {
1274     struct vhost_net *net = container_of(work, struct vhost_net,
1275                          poll[VHOST_NET_VQ_TX].work);
1276     handle_tx(net);
1277 }
1278 
1279 static void handle_rx_net(struct vhost_work *work)
1280 {
1281     struct vhost_net *net = container_of(work, struct vhost_net,
1282                          poll[VHOST_NET_VQ_RX].work);
1283     handle_rx(net);
1284 }
1285 
1286 static int vhost_net_open(struct inode *inode, struct file *f)
1287 {
1288     struct vhost_net *n;
1289     struct vhost_dev *dev;
1290     struct vhost_virtqueue **vqs;
1291     void **queue;
1292     struct xdp_buff *xdp;
1293     int i;
1294 
1295     n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1296     if (!n)
1297         return -ENOMEM;
1298     vqs = kmalloc_array(VHOST_NET_VQ_MAX, sizeof(*vqs), GFP_KERNEL);
1299     if (!vqs) {
1300         kvfree(n);
1301         return -ENOMEM;
1302     }
1303 
1304     queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *),
1305                   GFP_KERNEL);
1306     if (!queue) {
1307         kfree(vqs);
1308         kvfree(n);
1309         return -ENOMEM;
1310     }
1311     n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
1312 
1313     xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL);
1314     if (!xdp) {
1315         kfree(vqs);
1316         kvfree(n);
1317         kfree(queue);
1318         return -ENOMEM;
1319     }
1320     n->vqs[VHOST_NET_VQ_TX].xdp = xdp;
1321 
1322     dev = &n->dev;
1323     vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
1324     vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
1325     n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
1326     n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
1327     for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
1328         n->vqs[i].ubufs = NULL;
1329         n->vqs[i].ubuf_info = NULL;
1330         n->vqs[i].upend_idx = 0;
1331         n->vqs[i].done_idx = 0;
1332         n->vqs[i].batched_xdp = 0;
1333         n->vqs[i].vhost_hlen = 0;
1334         n->vqs[i].sock_hlen = 0;
1335         n->vqs[i].rx_ring = NULL;
1336         vhost_net_buf_init(&n->vqs[i].rxq);
1337     }
1338     vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
1339                UIO_MAXIOV + VHOST_NET_BATCH,
1340                VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT, true,
1341                NULL);
1342 
1343     vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
1344     vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
1345 
1346     f->private_data = n;
1347     n->page_frag.page = NULL;
1348     n->refcnt_bias = 0;
1349 
1350     return 0;
1351 }
1352 
1353 static struct socket *vhost_net_stop_vq(struct vhost_net *n,
1354                     struct vhost_virtqueue *vq)
1355 {
1356     struct socket *sock;
1357     struct vhost_net_virtqueue *nvq =
1358         container_of(vq, struct vhost_net_virtqueue, vq);
1359 
1360     mutex_lock(&vq->mutex);
1361     sock = vhost_vq_get_backend(vq);
1362     vhost_net_disable_vq(n, vq);
1363     vhost_vq_set_backend(vq, NULL);
1364     vhost_net_buf_unproduce(nvq);
1365     nvq->rx_ring = NULL;
1366     mutex_unlock(&vq->mutex);
1367     return sock;
1368 }
1369 
1370 static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
1371                struct socket **rx_sock)
1372 {
1373     *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
1374     *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
1375 }
1376 
1377 static void vhost_net_flush(struct vhost_net *n)
1378 {
1379     vhost_dev_flush(&n->dev);
1380     if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
1381         mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1382         n->tx_flush = true;
1383         mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1384         /* Wait for all lower device DMAs done. */
1385         vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
1386         mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1387         n->tx_flush = false;
1388         atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
1389         mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
1390     }
1391 }
1392 
/*
 * vhost_net_release - release handler for /dev/vhost-net.
 * @inode: unused here
 * @f:     the file whose private_data is the struct vhost_net
 *
 * Detaches both backends, flushes, stops and cleans up the vhost device,
 * drops the socket references returned by vhost_net_stop(), and frees the
 * memory allocated in vhost_net_open(). Always returns 0.
 */
static int vhost_net_release(struct inode *inode, struct file *f)
{
	struct vhost_net *n = f->private_data;
	struct socket *tx_sock;
	struct socket *rx_sock;

	/* Detach both backends; this also disables socket polling. */
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
	vhost_dev_stop(&n->dev);
	vhost_dev_cleanup(&n->dev);
	vhost_net_vq_reset(n);
	/* Drop the socket references handed back by vhost_net_stop(). */
	if (tx_sock)
		sockfd_put(tx_sock);
	if (rx_sock)
		sockfd_put(rx_sock);
	/* Make sure no callbacks are outstanding */
	synchronize_rcu();
	/* We do an extra flush before freeing memory,
	 * since jobs can re-queue themselves. */
	vhost_net_flush(n);
	/* Free the per-device buffers allocated in vhost_net_open(). */
	kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
	kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
	kfree(n->dev.vqs);
	if (n->page_frag.page)
		__page_frag_cache_drain(n->page_frag.page, n->refcnt_bias);
	kvfree(n);
	return 0;
}
1421 
1422 static struct socket *get_raw_socket(int fd)
1423 {
1424     int r;
1425     struct socket *sock = sockfd_lookup(fd, &r);
1426 
1427     if (!sock)
1428         return ERR_PTR(-ENOTSOCK);
1429 
1430     /* Parameter checking */
1431     if (sock->sk->sk_type != SOCK_RAW) {
1432         r = -ESOCKTNOSUPPORT;
1433         goto err;
1434     }
1435 
1436     if (sock->sk->sk_family != AF_PACKET) {
1437         r = -EPFNOSUPPORT;
1438         goto err;
1439     }
1440     return sock;
1441 err:
1442     sockfd_put(sock);
1443     return ERR_PTR(r);
1444 }
1445 
1446 static struct ptr_ring *get_tap_ptr_ring(struct file *file)
1447 {
1448     struct ptr_ring *ring;
1449     ring = tun_get_tx_ring(file);
1450     if (!IS_ERR(ring))
1451         goto out;
1452     ring = tap_get_ptr_ring(file);
1453     if (!IS_ERR(ring))
1454         goto out;
1455     ring = NULL;
1456 out:
1457     return ring;
1458 }
1459 
1460 static struct socket *get_tap_socket(int fd)
1461 {
1462     struct file *file = fget(fd);
1463     struct socket *sock;
1464 
1465     if (!file)
1466         return ERR_PTR(-EBADF);
1467     sock = tun_get_socket(file);
1468     if (!IS_ERR(sock))
1469         return sock;
1470     sock = tap_get_socket(file);
1471     if (IS_ERR(sock))
1472         fput(file);
1473     return sock;
1474 }
1475 
1476 static struct socket *get_socket(int fd)
1477 {
1478     struct socket *sock;
1479 
1480     /* special case to disable backend */
1481     if (fd == -1)
1482         return NULL;
1483     sock = get_raw_socket(fd);
1484     if (!IS_ERR(sock))
1485         return sock;
1486     sock = get_tap_socket(fd);
1487     if (!IS_ERR(sock))
1488         return sock;
1489     return ERR_PTR(-ENOTSOCK);
1490 }
1491 
/*
 * vhost_net_set_backend - attach, replace or (with fd == -1) remove the
 * backend socket of one virtqueue.
 * @n:     the vhost-net device
 * @index: virtqueue index, must be < VHOST_NET_VQ_MAX
 * @fd:    backend fd resolved by get_socket(), or -1 to detach
 *
 * Caller must be the device owner. Locks dev.mutex, then the vq mutex.
 * When the socket actually changes, a new ubuf_ref is obtained from
 * vhost_net_ubuf_alloc() (keyed on whether the new socket supports
 * zerocopy), polling moves to the new socket, the RX ptr_ring is looked
 * up, and the TX counters are reset. The old ubufs are drained and the
 * old socket reference dropped only after the vq mutex is released.
 *
 * Returns 0 on success or a negative errno.
 */
static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
	struct socket *sock, *oldsock;
	struct vhost_virtqueue *vq;
	struct vhost_net_virtqueue *nvq;
	struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
	int r;

	mutex_lock(&n->dev.mutex);
	r = vhost_dev_check_owner(&n->dev);
	if (r)
		goto err;

	if (index >= VHOST_NET_VQ_MAX) {
		r = -ENOBUFS;
		goto err;
	}
	vq = &n->vqs[index].vq;
	nvq = &n->vqs[index];
	mutex_lock(&vq->mutex);

	/* Verify that ring has been setup correctly. */
	if (!vhost_vq_access_ok(vq)) {
		r = -EFAULT;
		goto err_vq;
	}
	sock = get_socket(fd);
	if (IS_ERR(sock)) {
		r = PTR_ERR(sock);
		goto err_vq;
	}

	/* start polling new socket */
	oldsock = vhost_vq_get_backend(vq);
	if (sock != oldsock) {
		ubufs = vhost_net_ubuf_alloc(vq,
					     sock && vhost_sock_zcopy(sock));
		if (IS_ERR(ubufs)) {
			r = PTR_ERR(ubufs);
			goto err_ubufs;
		}

		vhost_net_disable_vq(n, vq);
		vhost_vq_set_backend(vq, sock);
		vhost_net_buf_unproduce(nvq);
		r = vhost_vq_init_access(vq);
		if (r)
			goto err_used;
		r = vhost_net_enable_vq(n, vq);
		if (r)
			goto err_used;
		if (index == VHOST_NET_VQ_RX) {
			if (sock)
				nvq->rx_ring = get_tap_ptr_ring(sock->file);
			else
				nvq->rx_ring = NULL;
		}

		oldubufs = nvq->ubufs;
		nvq->ubufs = ubufs;

		n->tx_packets = 0;
		n->tx_zcopy_err = 0;
		n->tx_flush = false;
	}

	mutex_unlock(&vq->mutex);

	/* Drain the previous ubuf_ref outside the vq mutex, then report any
	 * completed zerocopy buffers under it. */
	if (oldubufs) {
		vhost_net_ubuf_put_wait_and_free(oldubufs);
		mutex_lock(&vq->mutex);
		vhost_zerocopy_signal_used(n, vq);
		mutex_unlock(&vq->mutex);
	}

	/* Release the old socket. When sock == oldsock this drops the extra
	 * reference that get_socket() just took for the same socket. */
	if (oldsock) {
		vhost_dev_flush(&n->dev);
		sockfd_put(oldsock);
	}

	mutex_unlock(&n->dev.mutex);
	return 0;

err_used:
	/* Roll back to the previous backend and resume polling it. */
	vhost_vq_set_backend(vq, oldsock);
	vhost_net_enable_vq(n, vq);
	if (ubufs)
		vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
	if (sock)
		sockfd_put(sock);
err_vq:
	mutex_unlock(&vq->mutex);
err:
	mutex_unlock(&n->dev.mutex);
	return r;
}
1589 
1590 static long vhost_net_reset_owner(struct vhost_net *n)
1591 {
1592     struct socket *tx_sock = NULL;
1593     struct socket *rx_sock = NULL;
1594     long err;
1595     struct vhost_iotlb *umem;
1596 
1597     mutex_lock(&n->dev.mutex);
1598     err = vhost_dev_check_owner(&n->dev);
1599     if (err)
1600         goto done;
1601     umem = vhost_dev_reset_owner_prepare();
1602     if (!umem) {
1603         err = -ENOMEM;
1604         goto done;
1605     }
1606     vhost_net_stop(n, &tx_sock, &rx_sock);
1607     vhost_net_flush(n);
1608     vhost_dev_stop(&n->dev);
1609     vhost_dev_reset_owner(&n->dev, umem);
1610     vhost_net_vq_reset(n);
1611 done:
1612     mutex_unlock(&n->dev.mutex);
1613     if (tx_sock)
1614         sockfd_put(tx_sock);
1615     if (rx_sock)
1616         sockfd_put(rx_sock);
1617     return err;
1618 }
1619 
1620 static int vhost_net_set_features(struct vhost_net *n, u64 features)
1621 {
1622     size_t vhost_hlen, sock_hlen, hdr_len;
1623     int i;
1624 
1625     hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
1626                    (1ULL << VIRTIO_F_VERSION_1))) ?
1627             sizeof(struct virtio_net_hdr_mrg_rxbuf) :
1628             sizeof(struct virtio_net_hdr);
1629     if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
1630         /* vhost provides vnet_hdr */
1631         vhost_hlen = hdr_len;
1632         sock_hlen = 0;
1633     } else {
1634         /* socket provides vnet_hdr */
1635         vhost_hlen = 0;
1636         sock_hlen = hdr_len;
1637     }
1638     mutex_lock(&n->dev.mutex);
1639     if ((features & (1 << VHOST_F_LOG_ALL)) &&
1640         !vhost_log_access_ok(&n->dev))
1641         goto out_unlock;
1642 
1643     if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
1644         if (vhost_init_device_iotlb(&n->dev, true))
1645             goto out_unlock;
1646     }
1647 
1648     for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
1649         mutex_lock(&n->vqs[i].vq.mutex);
1650         n->vqs[i].vq.acked_features = features;
1651         n->vqs[i].vhost_hlen = vhost_hlen;
1652         n->vqs[i].sock_hlen = sock_hlen;
1653         mutex_unlock(&n->vqs[i].vq.mutex);
1654     }
1655     mutex_unlock(&n->dev.mutex);
1656     return 0;
1657 
1658 out_unlock:
1659     mutex_unlock(&n->dev.mutex);
1660     return -EFAULT;
1661 }
1662 
1663 static long vhost_net_set_owner(struct vhost_net *n)
1664 {
1665     int r;
1666 
1667     mutex_lock(&n->dev.mutex);
1668     if (vhost_dev_has_owner(&n->dev)) {
1669         r = -EBUSY;
1670         goto out;
1671     }
1672     r = vhost_net_set_ubuf_info(n);
1673     if (r)
1674         goto out;
1675     r = vhost_dev_set_owner(&n->dev);
1676     if (r)
1677         vhost_net_clear_ubuf_info(n);
1678     vhost_net_flush(n);
1679 out:
1680     mutex_unlock(&n->dev.mutex);
1681     return r;
1682 }
1683 
1684 static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
1685                 unsigned long arg)
1686 {
1687     struct vhost_net *n = f->private_data;
1688     void __user *argp = (void __user *)arg;
1689     u64 __user *featurep = argp;
1690     struct vhost_vring_file backend;
1691     u64 features;
1692     int r;
1693 
1694     switch (ioctl) {
1695     case VHOST_NET_SET_BACKEND:
1696         if (copy_from_user(&backend, argp, sizeof backend))
1697             return -EFAULT;
1698         return vhost_net_set_backend(n, backend.index, backend.fd);
1699     case VHOST_GET_FEATURES:
1700         features = VHOST_NET_FEATURES;
1701         if (copy_to_user(featurep, &features, sizeof features))
1702             return -EFAULT;
1703         return 0;
1704     case VHOST_SET_FEATURES:
1705         if (copy_from_user(&features, featurep, sizeof features))
1706             return -EFAULT;
1707         if (features & ~VHOST_NET_FEATURES)
1708             return -EOPNOTSUPP;
1709         return vhost_net_set_features(n, features);
1710     case VHOST_GET_BACKEND_FEATURES:
1711         features = VHOST_NET_BACKEND_FEATURES;
1712         if (copy_to_user(featurep, &features, sizeof(features)))
1713             return -EFAULT;
1714         return 0;
1715     case VHOST_SET_BACKEND_FEATURES:
1716         if (copy_from_user(&features, featurep, sizeof(features)))
1717             return -EFAULT;
1718         if (features & ~VHOST_NET_BACKEND_FEATURES)
1719             return -EOPNOTSUPP;
1720         vhost_set_backend_features(&n->dev, features);
1721         return 0;
1722     case VHOST_RESET_OWNER:
1723         return vhost_net_reset_owner(n);
1724     case VHOST_SET_OWNER:
1725         return vhost_net_set_owner(n);
1726     default:
1727         mutex_lock(&n->dev.mutex);
1728         r = vhost_dev_ioctl(&n->dev, ioctl, argp);
1729         if (r == -ENOIOCTLCMD)
1730             r = vhost_vring_ioctl(&n->dev, ioctl, argp);
1731         else
1732             vhost_net_flush(n);
1733         mutex_unlock(&n->dev.mutex);
1734         return r;
1735     }
1736 }
1737 
1738 static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
1739 {
1740     struct file *file = iocb->ki_filp;
1741     struct vhost_net *n = file->private_data;
1742     struct vhost_dev *dev = &n->dev;
1743     int noblock = file->f_flags & O_NONBLOCK;
1744 
1745     return vhost_chr_read_iter(dev, to, noblock);
1746 }
1747 
1748 static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
1749                     struct iov_iter *from)
1750 {
1751     struct file *file = iocb->ki_filp;
1752     struct vhost_net *n = file->private_data;
1753     struct vhost_dev *dev = &n->dev;
1754 
1755     return vhost_chr_write_iter(dev, from);
1756 }
1757 
1758 static __poll_t vhost_net_chr_poll(struct file *file, poll_table *wait)
1759 {
1760     struct vhost_net *n = file->private_data;
1761     struct vhost_dev *dev = &n->dev;
1762 
1763     return vhost_chr_poll(file, dev, wait);
1764 }
1765 
1766 static const struct file_operations vhost_net_fops = {
1767     .owner          = THIS_MODULE,
1768     .release        = vhost_net_release,
1769     .read_iter      = vhost_net_chr_read_iter,
1770     .write_iter     = vhost_net_chr_write_iter,
1771     .poll           = vhost_net_chr_poll,
1772     .unlocked_ioctl = vhost_net_ioctl,
1773     .compat_ioctl   = compat_ptr_ioctl,
1774     .open           = vhost_net_open,
1775     .llseek     = noop_llseek,
1776 };
1777 
1778 static struct miscdevice vhost_net_misc = {
1779     .minor = VHOST_NET_MINOR,
1780     .name = "vhost-net",
1781     .fops = &vhost_net_fops,
1782 };
1783 
1784 static int vhost_net_init(void)
1785 {
1786     if (experimental_zcopytx)
1787         vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
1788     return misc_register(&vhost_net_misc);
1789 }
1790 module_init(vhost_net_init);
1791 
1792 static void vhost_net_exit(void)
1793 {
1794     misc_deregister(&vhost_net_misc);
1795 }
1796 module_exit(vhost_net_exit);
1797 
/* Module metadata, plus aliases tying the module to the vhost-net misc
 * minor and the "vhost-net" device node name. */
MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
MODULE_ALIAS("devname:vhost-net");