Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /* Copyright (C) 2009 Red Hat, Inc.
0003  * Copyright (C) 2006 Rusty Russell IBM Corporation
0004  *
0005  * Author: Michael S. Tsirkin <mst@redhat.com>
0006  *
0007  * Inspiration, some code, and most witty comments come from
0008  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
0009  *
0010  * Generic code for virtio server in host kernel.
0011  */
0012 
0013 #include <linux/eventfd.h>
0014 #include <linux/vhost.h>
0015 #include <linux/uio.h>
0016 #include <linux/mm.h>
0017 #include <linux/miscdevice.h>
0018 #include <linux/mutex.h>
0019 #include <linux/poll.h>
0020 #include <linux/file.h>
0021 #include <linux/highmem.h>
0022 #include <linux/slab.h>
0023 #include <linux/vmalloc.h>
0024 #include <linux/kthread.h>
0025 #include <linux/cgroup.h>
0026 #include <linux/module.h>
0027 #include <linux/sort.h>
0028 #include <linux/sched/mm.h>
0029 #include <linux/sched/signal.h>
0030 #include <linux/interval_tree_generic.h>
0031 #include <linux/nospec.h>
0032 #include <linux/kcov.h>
0033 
0034 #include "vhost.h"
0035 
/* Maximum number of regions accepted in a memory map.
 * Exposed as a read-only (0444) module parameter; default 64.
 */
static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444);
MODULE_PARM_DESC(max_mem_regions,
    "Maximum number of memory regions in memory map. (default: 64)");
/* Capacity passed to vhost_iotlb_alloc() by iotlb_alloc().
 * Exposed as a read-only (0444) module parameter; default 2048.
 */
static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
    "Maximum number of iotlb entries. (default: 2048)");

/* Memory-table flag bits. NOTE(review): not referenced in the visible part
 * of this file; presumably requests dirty logging for the table - confirm.
 */
enum {
    VHOST_MEMORY_F_LOG = 0x1,
};
0048 
/* Userspace address of the used_event field, which is stored in the slot
 * just past the last avail ring entry (avail->ring[num]).
 * The argument is fully parenthesized so that any pointer expression
 * (e.g. "*pvq") can be passed safely.
 */
#define vhost_used_event(vq) ((__virtio16 __user *)&(vq)->avail->ring[(vq)->num])
/* Userspace address of the avail_event field, which is stored in the slot
 * just past the last used ring entry (used->ring[num]).
 */
#define vhost_avail_event(vq) ((__virtio16 __user *)&(vq)->used->ring[(vq)->num])
0051 
0052 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
0053 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
0054 {
0055     vq->user_be = !virtio_legacy_is_little_endian();
0056 }
0057 
0058 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
0059 {
0060     vq->user_be = true;
0061 }
0062 
0063 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
0064 {
0065     vq->user_be = false;
0066 }
0067 
0068 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
0069 {
0070     struct vhost_vring_state s;
0071 
0072     if (vq->private_data)
0073         return -EBUSY;
0074 
0075     if (copy_from_user(&s, argp, sizeof(s)))
0076         return -EFAULT;
0077 
0078     if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
0079         s.num != VHOST_VRING_BIG_ENDIAN)
0080         return -EINVAL;
0081 
0082     if (s.num == VHOST_VRING_BIG_ENDIAN)
0083         vhost_enable_cross_endian_big(vq);
0084     else
0085         vhost_enable_cross_endian_little(vq);
0086 
0087     return 0;
0088 }
0089 
0090 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
0091                    int __user *argp)
0092 {
0093     struct vhost_vring_state s = {
0094         .index = idx,
0095         .num = vq->user_be
0096     };
0097 
0098     if (copy_to_user(argp, &s, sizeof(s)))
0099         return -EFAULT;
0100 
0101     return 0;
0102 }
0103 
0104 static void vhost_init_is_le(struct vhost_virtqueue *vq)
0105 {
0106     /* Note for legacy virtio: user_be is initialized at reset time
0107      * according to the host endianness. If userspace does not set an
0108      * explicit endianness, the default behavior is native endian, as
0109      * expected by legacy virtio.
0110      */
0111     vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
0112 }
0113 #else
0114 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
0115 {
0116 }
0117 
0118 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
0119 {
0120     return -ENOIOCTLCMD;
0121 }
0122 
0123 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
0124                    int __user *argp)
0125 {
0126     return -ENOIOCTLCMD;
0127 }
0128 
0129 static void vhost_init_is_le(struct vhost_virtqueue *vq)
0130 {
0131     vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
0132         || virtio_legacy_is_little_endian();
0133 }
0134 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
0135 
/* Recompute vq->is_le. This is a thin alias for vhost_init_is_le(). */
static void vhost_reset_is_le(struct vhost_virtqueue *vq)
{
    vhost_init_is_le(vq);
}
0140 
0141 struct vhost_flush_struct {
0142     struct vhost_work work;
0143     struct completion wait_event;
0144 };
0145 
0146 static void vhost_flush_work(struct vhost_work *work)
0147 {
0148     struct vhost_flush_struct *s;
0149 
0150     s = container_of(work, struct vhost_flush_struct, work);
0151     complete(&s->wait_event);
0152 }
0153 
0154 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
0155                 poll_table *pt)
0156 {
0157     struct vhost_poll *poll;
0158 
0159     poll = container_of(pt, struct vhost_poll, table);
0160     poll->wqh = wqh;
0161     add_wait_queue(wqh, &poll->wait);
0162 }
0163 
0164 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
0165                  void *key)
0166 {
0167     struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
0168     struct vhost_work *work = &poll->work;
0169 
0170     if (!(key_to_poll(key) & poll->mask))
0171         return 0;
0172 
0173     if (!poll->dev->use_worker)
0174         work->fn(work);
0175     else
0176         vhost_poll_queue(poll);
0177 
0178     return 0;
0179 }
0180 
0181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
0182 {
0183     clear_bit(VHOST_WORK_QUEUED, &work->flags);
0184     work->fn = fn;
0185 }
0186 EXPORT_SYMBOL_GPL(vhost_work_init);
0187 
0188 /* Init poll structure */
0189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
0190              __poll_t mask, struct vhost_dev *dev)
0191 {
0192     init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
0193     init_poll_funcptr(&poll->table, vhost_poll_func);
0194     poll->mask = mask;
0195     poll->dev = dev;
0196     poll->wqh = NULL;
0197 
0198     vhost_work_init(&poll->work, fn);
0199 }
0200 EXPORT_SYMBOL_GPL(vhost_poll_init);
0201 
0202 /* Start polling a file. We add ourselves to file's wait queue. The caller must
0203  * keep a reference to a file until after vhost_poll_stop is called. */
0204 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
0205 {
0206     __poll_t mask;
0207 
0208     if (poll->wqh)
0209         return 0;
0210 
0211     mask = vfs_poll(file, &poll->table);
0212     if (mask)
0213         vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
0214     if (mask & EPOLLERR) {
0215         vhost_poll_stop(poll);
0216         return -EINVAL;
0217     }
0218 
0219     return 0;
0220 }
0221 EXPORT_SYMBOL_GPL(vhost_poll_start);
0222 
0223 /* Stop polling a file. After this function returns, it becomes safe to drop the
0224  * file reference. You must also flush afterwards. */
0225 void vhost_poll_stop(struct vhost_poll *poll)
0226 {
0227     if (poll->wqh) {
0228         remove_wait_queue(poll->wqh, &poll->wait);
0229         poll->wqh = NULL;
0230     }
0231 }
0232 EXPORT_SYMBOL_GPL(vhost_poll_stop);
0233 
0234 void vhost_dev_flush(struct vhost_dev *dev)
0235 {
0236     struct vhost_flush_struct flush;
0237 
0238     if (dev->worker) {
0239         init_completion(&flush.wait_event);
0240         vhost_work_init(&flush.work, vhost_flush_work);
0241 
0242         vhost_work_queue(dev, &flush.work);
0243         wait_for_completion(&flush.wait_event);
0244     }
0245 }
0246 EXPORT_SYMBOL_GPL(vhost_dev_flush);
0247 
0248 void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
0249 {
0250     if (!dev->worker)
0251         return;
0252 
0253     if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
0254         /* We can only add the work to the list after we're
0255          * sure it was not in the list.
0256          * test_and_set_bit() implies a memory barrier.
0257          */
0258         llist_add(&work->node, &dev->work_list);
0259         wake_up_process(dev->worker);
0260     }
0261 }
0262 EXPORT_SYMBOL_GPL(vhost_work_queue);
0263 
0264 /* A lockless hint for busy polling code to exit the loop */
0265 bool vhost_has_work(struct vhost_dev *dev)
0266 {
0267     return !llist_empty(&dev->work_list);
0268 }
0269 EXPORT_SYMBOL_GPL(vhost_has_work);
0270 
0271 void vhost_poll_queue(struct vhost_poll *poll)
0272 {
0273     vhost_work_queue(poll->dev, &poll->work);
0274 }
0275 EXPORT_SYMBOL_GPL(vhost_poll_queue);
0276 
0277 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
0278 {
0279     int j;
0280 
0281     for (j = 0; j < VHOST_NUM_ADDRS; j++)
0282         vq->meta_iotlb[j] = NULL;
0283 }
0284 
0285 static void vhost_vq_meta_reset(struct vhost_dev *d)
0286 {
0287     int i;
0288 
0289     for (i = 0; i < d->nvqs; ++i)
0290         __vhost_vq_meta_reset(d->vqs[i]);
0291 }
0292 
0293 static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
0294 {
0295     call_ctx->ctx = NULL;
0296     memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
0297 }
0298 
0299 bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
0300 {
0301     return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
0302 }
0303 EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
0304 
/* Return @vq to its unconfigured state. This clears the ring addresses and
 * indices, private_data, the acked features, the log settings, the
 * eventfd/kick pointers, the memory table, the IOTLB, the call context and
 * the cached meta mappings.
 * References held in error_ctx, kick and call_ctx.ctx are NOT released
 * here; vhost_dev_cleanup() puts them before calling this.
 * @dev is not used.
 */
static void vhost_vq_reset(struct vhost_dev *dev,
               struct vhost_virtqueue *vq)
{
    vq->num = 1;
    vq->desc = NULL;
    vq->avail = NULL;
    vq->used = NULL;
    vq->last_avail_idx = 0;
    vq->avail_idx = 0;
    vq->last_used_idx = 0;
    vq->signalled_used = 0;
    vq->signalled_used_valid = false;
    vq->used_flags = 0;
    vq->log_used = false;
    /* All-ones: presumably "no used-ring log address" - confirm. */
    vq->log_addr = -1ull;
    vq->private_data = NULL;
    vq->acked_features = 0;
    vq->acked_backend_features = 0;
    vq->log_base = NULL;
    vq->error_ctx = NULL;
    vq->kick = NULL;
    vq->log_ctx = NULL;
    /* Order matters: vhost_reset_is_le() reads user_be (set by
     * vhost_disable_cross_endian()) and acked_features (cleared above).
     */
    vhost_disable_cross_endian(vq);
    vhost_reset_is_le(vq);
    vq->busyloop_timeout = 0;
    vq->umem = NULL;
    vq->iotlb = NULL;
    vhost_vring_call_reset(&vq->call_ctx);
    __vhost_vq_meta_reset(vq);
}
0335 
/* Worker kthread main loop. It adopts the owner's mm with kthread_use_mm(),
 * then repeatedly takes everything queued on dev->work_list and runs the
 * items in the order they were queued, until kthread_stop() is called.
 * Always returns 0.
 */
static int vhost_worker(void *data)
{
    struct vhost_dev *dev = data;
    struct vhost_work *work, *work_next;
    struct llist_node *node;

    kthread_use_mm(dev->mm);

    for (;;) {
        /* mb paired w/ kthread_stop */
        set_current_state(TASK_INTERRUPTIBLE);

        if (kthread_should_stop()) {
            __set_current_state(TASK_RUNNING);
            break;
        }

        /* Detach the whole pending list at once; sleep if it is empty. */
        node = llist_del_all(&dev->work_list);
        if (!node)
            schedule();

        /* llist_add() pushes at the head, so reverse to get queueing
         * order. After an empty wakeup node is NULL and the loop below
         * does nothing.
         */
        node = llist_reverse_order(node);
        /* make sure flag is seen after deletion */
        smp_wmb();
        llist_for_each_entry_safe(work, work_next, node, node) {
            /* Clear QUEUED before running, so vhost_work_queue() can
             * re-add this item (even from inside fn).
             */
            clear_bit(VHOST_WORK_QUEUED, &work->flags);
            __set_current_state(TASK_RUNNING);
            kcov_remote_start_common(dev->kcov_handle);
            work->fn(work);
            kcov_remote_stop();
            /* Yield between items if a reschedule is pending. */
            if (need_resched())
                schedule();
        }
    }
    kthread_unuse_mm(dev->mm);
    return 0;
}
0373 
0374 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
0375 {
0376     kfree(vq->indirect);
0377     vq->indirect = NULL;
0378     kfree(vq->log);
0379     vq->log = NULL;
0380     kfree(vq->heads);
0381     vq->heads = NULL;
0382 }
0383 
0384 /* Helper to allocate iovec buffers for all vqs. */
0385 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
0386 {
0387     struct vhost_virtqueue *vq;
0388     int i;
0389 
0390     for (i = 0; i < dev->nvqs; ++i) {
0391         vq = dev->vqs[i];
0392         vq->indirect = kmalloc_array(UIO_MAXIOV,
0393                          sizeof(*vq->indirect),
0394                          GFP_KERNEL);
0395         vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
0396                     GFP_KERNEL);
0397         vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
0398                       GFP_KERNEL);
0399         if (!vq->indirect || !vq->log || !vq->heads)
0400             goto err_nomem;
0401     }
0402     return 0;
0403 
0404 err_nomem:
0405     for (; i >= 0; --i)
0406         vhost_vq_free_iovecs(dev->vqs[i]);
0407     return -ENOMEM;
0408 }
0409 
0410 static void vhost_dev_free_iovecs(struct vhost_dev *dev)
0411 {
0412     int i;
0413 
0414     for (i = 0; i < dev->nvqs; ++i)
0415         vhost_vq_free_iovecs(dev->vqs[i]);
0416 }
0417 
0418 bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
0419               int pkts, int total_len)
0420 {
0421     struct vhost_dev *dev = vq->dev;
0422 
0423     if ((dev->byte_weight && total_len >= dev->byte_weight) ||
0424         pkts >= dev->weight) {
0425         vhost_poll_queue(&vq->poll);
0426         return true;
0427     }
0428 
0429     return false;
0430 }
0431 EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
0432 
0433 static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
0434                    unsigned int num)
0435 {
0436     size_t event __maybe_unused =
0437            vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
0438 
0439     return sizeof(*vq->avail) +
0440            sizeof(*vq->avail->ring) * num + event;
0441 }
0442 
0443 static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
0444                   unsigned int num)
0445 {
0446     size_t event __maybe_unused =
0447            vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
0448 
0449     return sizeof(*vq->used) +
0450            sizeof(*vq->used->ring) * num + event;
0451 }
0452 
0453 static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
0454                   unsigned int num)
0455 {
0456     return sizeof(*vq->desc) * num;
0457 }
0458 
0459 void vhost_dev_init(struct vhost_dev *dev,
0460             struct vhost_virtqueue **vqs, int nvqs,
0461             int iov_limit, int weight, int byte_weight,
0462             bool use_worker,
0463             int (*msg_handler)(struct vhost_dev *dev, u32 asid,
0464                        struct vhost_iotlb_msg *msg))
0465 {
0466     struct vhost_virtqueue *vq;
0467     int i;
0468 
0469     dev->vqs = vqs;
0470     dev->nvqs = nvqs;
0471     mutex_init(&dev->mutex);
0472     dev->log_ctx = NULL;
0473     dev->umem = NULL;
0474     dev->iotlb = NULL;
0475     dev->mm = NULL;
0476     dev->worker = NULL;
0477     dev->iov_limit = iov_limit;
0478     dev->weight = weight;
0479     dev->byte_weight = byte_weight;
0480     dev->use_worker = use_worker;
0481     dev->msg_handler = msg_handler;
0482     init_llist_head(&dev->work_list);
0483     init_waitqueue_head(&dev->wait);
0484     INIT_LIST_HEAD(&dev->read_list);
0485     INIT_LIST_HEAD(&dev->pending_list);
0486     spin_lock_init(&dev->iotlb_lock);
0487 
0488 
0489     for (i = 0; i < dev->nvqs; ++i) {
0490         vq = dev->vqs[i];
0491         vq->log = NULL;
0492         vq->indirect = NULL;
0493         vq->heads = NULL;
0494         vq->dev = dev;
0495         mutex_init(&vq->mutex);
0496         vhost_vq_reset(dev, vq);
0497         if (vq->handle_kick)
0498             vhost_poll_init(&vq->poll, vq->handle_kick,
0499                     EPOLLIN, dev);
0500     }
0501 }
0502 EXPORT_SYMBOL_GPL(vhost_dev_init);
0503 
0504 /* Caller should have device mutex */
0505 long vhost_dev_check_owner(struct vhost_dev *dev)
0506 {
0507     /* Are you the owner? If not, I don't think you mean to do that */
0508     return dev->mm == current->mm ? 0 : -EPERM;
0509 }
0510 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
0511 
0512 struct vhost_attach_cgroups_struct {
0513     struct vhost_work work;
0514     struct task_struct *owner;
0515     int ret;
0516 };
0517 
0518 static void vhost_attach_cgroups_work(struct vhost_work *work)
0519 {
0520     struct vhost_attach_cgroups_struct *s;
0521 
0522     s = container_of(work, struct vhost_attach_cgroups_struct, work);
0523     s->ret = cgroup_attach_task_all(s->owner, current);
0524 }
0525 
0526 static int vhost_attach_cgroups(struct vhost_dev *dev)
0527 {
0528     struct vhost_attach_cgroups_struct attach;
0529 
0530     attach.owner = current;
0531     vhost_work_init(&attach.work, vhost_attach_cgroups_work);
0532     vhost_work_queue(dev, &attach.work);
0533     vhost_dev_flush(dev);
0534     return attach.ret;
0535 }
0536 
0537 /* Caller should have device mutex */
0538 bool vhost_dev_has_owner(struct vhost_dev *dev)
0539 {
0540     return dev->mm;
0541 }
0542 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
0543 
0544 static void vhost_attach_mm(struct vhost_dev *dev)
0545 {
0546     /* No owner, become one */
0547     if (dev->use_worker) {
0548         dev->mm = get_task_mm(current);
0549     } else {
0550         /* vDPA device does not use worker thead, so there's
0551          * no need to hold the address space for mm. This help
0552          * to avoid deadlock in the case of mmap() which may
0553          * held the refcnt of the file and depends on release
0554          * method to remove vma.
0555          */
0556         dev->mm = current->mm;
0557         mmgrab(dev->mm);
0558     }
0559 }
0560 
0561 static void vhost_detach_mm(struct vhost_dev *dev)
0562 {
0563     if (!dev->mm)
0564         return;
0565 
0566     if (dev->use_worker)
0567         mmput(dev->mm);
0568     else
0569         mmdrop(dev->mm);
0570 
0571     dev->mm = NULL;
0572 }
0573 
/* Caller should have device mutex */
/* Make the calling process the owner of @dev. The steps are:
 *  - take a reference on the caller's mm (vhost_attach_mm());
 *  - if the device uses a worker, create the "vhost-<pid>" kthread, wake
 *    it, and attach it to the caller's cgroups;
 *  - allocate the per-vq arrays.
 * Returns 0 on success and -EBUSY if an owner already exists. Otherwise it
 * returns the error of the failing step, with the earlier steps undone.
 */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
    struct task_struct *worker;
    int err;

    /* Is there an owner already? */
    if (vhost_dev_has_owner(dev)) {
        err = -EBUSY;
        goto err_mm;
    }

    vhost_attach_mm(dev);

    dev->kcov_handle = kcov_common_handle();
    if (dev->use_worker) {
        worker = kthread_create(vhost_worker, dev,
                    "vhost-%d", current->pid);
        if (IS_ERR(worker)) {
            err = PTR_ERR(worker);
            goto err_worker;
        }

        dev->worker = worker;
        wake_up_process(worker); /* avoid contributing to loadavg */

        err = vhost_attach_cgroups(dev);
        if (err)
            goto err_cgroup;
    }

    err = vhost_dev_alloc_iovecs(dev);
    if (err)
        goto err_cgroup;

    return 0;
    /* Unwind in reverse order: stop the worker if one was started... */
err_cgroup:
    if (dev->worker) {
        kthread_stop(dev->worker);
        dev->worker = NULL;
    }
    /* ...then drop the mm reference and the kcov handle. */
err_worker:
    vhost_detach_mm(dev);
    dev->kcov_handle = 0;
err_mm:
    return err;
}
EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
0622 
0623 static struct vhost_iotlb *iotlb_alloc(void)
0624 {
0625     return vhost_iotlb_alloc(max_iotlb_entries,
0626                  VHOST_IOTLB_FLAG_RETIRE);
0627 }
0628 
0629 struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
0630 {
0631     return iotlb_alloc();
0632 }
0633 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
0634 
0635 /* Caller should have device mutex */
0636 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
0637 {
0638     int i;
0639 
0640     vhost_dev_cleanup(dev);
0641 
0642     dev->umem = umem;
0643     /* We don't need VQ locks below since vhost_dev_cleanup makes sure
0644      * VQs aren't running.
0645      */
0646     for (i = 0; i < dev->nvqs; ++i)
0647         dev->vqs[i]->umem = umem;
0648 }
0649 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
0650 
0651 void vhost_dev_stop(struct vhost_dev *dev)
0652 {
0653     int i;
0654 
0655     for (i = 0; i < dev->nvqs; ++i) {
0656         if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
0657             vhost_poll_stop(&dev->vqs[i]->poll);
0658     }
0659 
0660     vhost_dev_flush(dev);
0661 }
0662 EXPORT_SYMBOL_GPL(vhost_dev_stop);
0663 
0664 static void vhost_clear_msg(struct vhost_dev *dev)
0665 {
0666     struct vhost_msg_node *node, *n;
0667 
0668     spin_lock(&dev->iotlb_lock);
0669 
0670     list_for_each_entry_safe(node, n, &dev->read_list, node) {
0671         list_del(&node->node);
0672         kfree(node);
0673     }
0674 
0675     list_for_each_entry_safe(node, n, &dev->pending_list, node) {
0676         list_del(&node->node);
0677         kfree(node);
0678     }
0679 
0680     spin_unlock(&dev->iotlb_lock);
0681 }
0682 
/* Release everything the device holds, in this order:
 *  - per-vq error/call eventfds and the kick file (then reset each vq);
 *  - the per-vq arrays and the log eventfd;
 *  - the memory table and the IOTLB;
 *  - queued IOTLB messages (readers of dev->wait are woken);
 *  - the worker thread, then the mm reference.
 */
void vhost_dev_cleanup(struct vhost_dev *dev)
{
    int i;

    for (i = 0; i < dev->nvqs; ++i) {
        if (dev->vqs[i]->error_ctx)
            eventfd_ctx_put(dev->vqs[i]->error_ctx);
        if (dev->vqs[i]->kick)
            fput(dev->vqs[i]->kick);
        if (dev->vqs[i]->call_ctx.ctx)
            eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
        /* References dropped above; now clear the pointers. */
        vhost_vq_reset(dev, dev->vqs[i]);
    }
    vhost_dev_free_iovecs(dev);
    if (dev->log_ctx)
        eventfd_ctx_put(dev->log_ctx);
    dev->log_ctx = NULL;
    /* No one will access memory at this point */
    vhost_iotlb_free(dev->umem);
    dev->umem = NULL;
    vhost_iotlb_free(dev->iotlb);
    dev->iotlb = NULL;
    vhost_clear_msg(dev);
    wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
    /* NOTE(review): expects no pending work here; callers presumably
     * stop/flush beforehand - confirm.
     */
    WARN_ON(!llist_empty(&dev->work_list));
    if (dev->worker) {
        kthread_stop(dev->worker);
        dev->worker = NULL;
        dev->kcov_handle = 0;
    }
    vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
0716 
0717 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
0718 {
0719     u64 a = addr / VHOST_PAGE_SIZE / 8;
0720 
0721     /* Make sure 64 bit math will not overflow. */
0722     if (a > ULONG_MAX - (unsigned long)log_base ||
0723         a + (unsigned long)log_base > ULONG_MAX)
0724         return false;
0725 
0726     return access_ok(log_base + a,
0727              (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
0728 }
0729 
0730 /* Make sure 64 bit math will not overflow. */
0731 static bool vhost_overflow(u64 uaddr, u64 size)
0732 {
0733     if (uaddr > ULONG_MAX || size > ULONG_MAX)
0734         return true;
0735 
0736     if (!size)
0737         return false;
0738 
0739     return uaddr > ULONG_MAX - size + 1;
0740 }
0741 
0742 /* Caller should have vq mutex and device mutex. */
0743 static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
0744                 int log_all)
0745 {
0746     struct vhost_iotlb_map *map;
0747 
0748     if (!umem)
0749         return false;
0750 
0751     list_for_each_entry(map, &umem->list, link) {
0752         unsigned long a = map->addr;
0753 
0754         if (vhost_overflow(map->addr, map->size))
0755             return false;
0756 
0757 
0758         if (!access_ok((void __user *)a, map->size))
0759             return false;
0760         else if (log_all && !log_access_ok(log_base,
0761                            map->start,
0762                            map->size))
0763             return false;
0764     }
0765     return true;
0766 }
0767 
0768 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
0769                            u64 addr, unsigned int size,
0770                            int type)
0771 {
0772     const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
0773 
0774     if (!map)
0775         return NULL;
0776 
0777     return (void __user *)(uintptr_t)(map->addr + addr - map->start);
0778 }
0779 
0780 /* Can we switch to this memory table? */
0781 /* Caller should have device mutex but not vq mutex */
0782 static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
0783                  int log_all)
0784 {
0785     int i;
0786 
0787     for (i = 0; i < d->nvqs; ++i) {
0788         bool ok;
0789         bool log;
0790 
0791         mutex_lock(&d->vqs[i]->mutex);
0792         log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
0793         /* If ring is inactive, will check when it's enabled. */
0794         if (d->vqs[i]->private_data)
0795             ok = vq_memory_access_ok(d->vqs[i]->log_base,
0796                          umem, log);
0797         else
0798             ok = true;
0799         mutex_unlock(&d->vqs[i]->mutex);
0800         if (!ok)
0801             return false;
0802     }
0803     return true;
0804 }
0805 
0806 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
0807               struct iovec iov[], int iov_size, int access);
0808 
0809 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
0810                   const void *from, unsigned size)
0811 {
0812     int ret;
0813 
0814     if (!vq->iotlb)
0815         return __copy_to_user(to, from, size);
0816     else {
0817         /* This function should be called after iotlb
0818          * prefetch, which means we're sure that all vq
0819          * could be access through iotlb. So -EAGAIN should
0820          * not happen in this case.
0821          */
0822         struct iov_iter t;
0823         void __user *uaddr = vhost_vq_meta_fetch(vq,
0824                      (u64)(uintptr_t)to, size,
0825                      VHOST_ADDR_USED);
0826 
0827         if (uaddr)
0828             return __copy_to_user(uaddr, from, size);
0829 
0830         ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
0831                      ARRAY_SIZE(vq->iotlb_iov),
0832                      VHOST_ACCESS_WO);
0833         if (ret < 0)
0834             goto out;
0835         iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
0836         ret = copy_to_iter(from, size, &t);
0837         if (ret == size)
0838             ret = 0;
0839     }
0840 out:
0841     return ret;
0842 }
0843 
0844 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
0845                 void __user *from, unsigned size)
0846 {
0847     int ret;
0848 
0849     if (!vq->iotlb)
0850         return __copy_from_user(to, from, size);
0851     else {
0852         /* This function should be called after iotlb
0853          * prefetch, which means we're sure that vq
0854          * could be access through iotlb. So -EAGAIN should
0855          * not happen in this case.
0856          */
0857         void __user *uaddr = vhost_vq_meta_fetch(vq,
0858                      (u64)(uintptr_t)from, size,
0859                      VHOST_ADDR_DESC);
0860         struct iov_iter f;
0861 
0862         if (uaddr)
0863             return __copy_from_user(to, uaddr, size);
0864 
0865         ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
0866                      ARRAY_SIZE(vq->iotlb_iov),
0867                      VHOST_ACCESS_RO);
0868         if (ret < 0) {
0869             vq_err(vq, "IOTLB translation failure: uaddr "
0870                    "%p size 0x%llx\n", from,
0871                    (unsigned long long) size);
0872             goto out;
0873         }
0874         iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
0875         ret = copy_from_iter(to, size, &f);
0876         if (ret == size)
0877             ret = 0;
0878     }
0879 
0880 out:
0881     return ret;
0882 }
0883 
0884 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
0885                       void __user *addr, unsigned int size,
0886                       int type)
0887 {
0888     int ret;
0889 
0890     ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
0891                  ARRAY_SIZE(vq->iotlb_iov),
0892                  VHOST_ACCESS_RO);
0893     if (ret < 0) {
0894         vq_err(vq, "IOTLB translation failure: uaddr "
0895             "%p size 0x%llx\n", addr,
0896             (unsigned long long) size);
0897         return NULL;
0898     }
0899 
0900     if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
0901         vq_err(vq, "Non atomic userspace memory access: uaddr "
0902             "%p size 0x%llx\n", addr,
0903             (unsigned long long) size);
0904         return NULL;
0905     }
0906 
0907     return vq->iotlb_iov[0].iov_base;
0908 }
0909 
0910 /* This function should be called after iotlb
0911  * prefetch, which means we're sure that vq
0912  * could be access through iotlb. So -EAGAIN should
0913  * not happen in this case.
0914  */
0915 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
0916                         void __user *addr, unsigned int size,
0917                         int type)
0918 {
0919     void __user *uaddr = vhost_vq_meta_fetch(vq,
0920                  (u64)(uintptr_t)addr, size, type);
0921     if (uaddr)
0922         return uaddr;
0923 
0924     return __vhost_get_user_slow(vq, addr, size, type);
0925 }
0926 
/* Store @x to the ring field at user address @ptr. Without an IOTLB, @ptr
 * is written directly. Otherwise it is first translated as a
 * VHOST_ADDR_USED access, and -EFAULT results if that fails.
 * Evaluates to the __put_user() result or -EFAULT.
 *
 * The statement-expression locals carry a __vhost_ prefix so they cannot
 * capture identically named identifiers passed by the caller in @x or
 * @ptr. @vq and @ptr are parenthesized where they are dereferenced.
 */
#define vhost_put_user(vq, x, ptr)      \
({ \
    int __vhost_ret; \
    if (!(vq)->iotlb) { \
        __vhost_ret = __put_user(x, ptr); \
    } else { \
        __typeof__(ptr) __vhost_to = \
            (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
                      sizeof(*(ptr)), VHOST_ADDR_USED); \
        if (__vhost_to != NULL) \
            __vhost_ret = __put_user(x, __vhost_to); \
        else \
            __vhost_ret = -EFAULT;  \
    } \
    __vhost_ret; \
})
0943 
0944 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
0945 {
0946     return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
0947                   vhost_avail_event(vq));
0948 }
0949 
0950 static inline int vhost_put_used(struct vhost_virtqueue *vq,
0951                  struct vring_used_elem *head, int idx,
0952                  int count)
0953 {
0954     return vhost_copy_to_user(vq, vq->used->ring + idx, head,
0955                   count * sizeof(*head));
0956 }
0957 
0958 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
0959 
0960 {
0961     return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
0962                   &vq->used->flags);
0963 }
0964 
0965 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
0966 
0967 {
0968     return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
0969                   &vq->used->idx);
0970 }
0971 
/* Load the ring field at user address @ptr into the lvalue @x. When an
 * IOTLB is present, @ptr is first translated as metadata @type, and
 * -EFAULT results if that fails.
 * Evaluates to the __get_user() result or -EFAULT.
 *
 * The statement-expression locals carry a __vhost_ prefix so they cannot
 * capture identically named identifiers passed by the caller in @x or
 * @ptr. @vq and @ptr are parenthesized where they are dereferenced.
 */
#define vhost_get_user(vq, x, ptr, type)        \
({ \
    int __vhost_ret; \
    if (!(vq)->iotlb) { \
        __vhost_ret = __get_user(x, ptr); \
    } else { \
        __typeof__(ptr) __vhost_from = \
            (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
                               sizeof(*(ptr)), \
                               type); \
        if (__vhost_from != NULL) \
            __vhost_ret = __get_user(x, __vhost_from); \
        else \
            __vhost_ret = -EFAULT; \
    } \
    __vhost_ret; \
})

/* vhost_get_user() for fields reached through the avail ring mapping. */
#define vhost_get_avail(vq, x, ptr) \
    vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)

/* vhost_get_user() for fields reached through the used ring mapping. */
#define vhost_get_used(vq, x, ptr) \
    vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
0995 
0996 static void vhost_dev_lock_vqs(struct vhost_dev *d)
0997 {
0998     int i = 0;
0999     for (i = 0; i < d->nvqs; ++i)
1000         mutex_lock_nested(&d->vqs[i]->mutex, i);
1001 }
1002 
1003 static void vhost_dev_unlock_vqs(struct vhost_dev *d)
1004 {
1005     int i = 0;
1006     for (i = 0; i < d->nvqs; ++i)
1007         mutex_unlock(&d->vqs[i]->mutex);
1008 }
1009 
1010 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
1011                       __virtio16 *idx)
1012 {
1013     return vhost_get_avail(vq, *idx, &vq->avail->idx);
1014 }
1015 
1016 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
1017                        __virtio16 *head, int idx)
1018 {
1019     return vhost_get_avail(vq, *head,
1020                    &vq->avail->ring[idx & (vq->num - 1)]);
1021 }
1022 
1023 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
1024                     __virtio16 *flags)
1025 {
1026     return vhost_get_avail(vq, *flags, &vq->avail->flags);
1027 }
1028 
1029 static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
1030                        __virtio16 *event)
1031 {
1032     return vhost_get_avail(vq, *event, vhost_used_event(vq));
1033 }
1034 
1035 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
1036                      __virtio16 *idx)
1037 {
1038     return vhost_get_used(vq, *idx, &vq->used->idx);
1039 }
1040 
1041 static inline int vhost_get_desc(struct vhost_virtqueue *vq,
1042                  struct vring_desc *desc, int idx)
1043 {
1044     return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
1045 }
1046 
1047 static void vhost_iotlb_notify_vq(struct vhost_dev *d,
1048                   struct vhost_iotlb_msg *msg)
1049 {
1050     struct vhost_msg_node *node, *n;
1051 
1052     spin_lock(&d->iotlb_lock);
1053 
1054     list_for_each_entry_safe(node, n, &d->pending_list, node) {
1055         struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
1056         if (msg->iova <= vq_msg->iova &&
1057             msg->iova + msg->size - 1 >= vq_msg->iova &&
1058             vq_msg->type == VHOST_IOTLB_MISS) {
1059             vhost_poll_queue(&node->vq->poll);
1060             list_del(&node->node);
1061             kfree(node);
1062         }
1063     }
1064 
1065     spin_unlock(&d->iotlb_lock);
1066 }
1067 
1068 static bool umem_access_ok(u64 uaddr, u64 size, int access)
1069 {
1070     unsigned long a = uaddr;
1071 
1072     /* Make sure 64 bit math will not overflow. */
1073     if (vhost_overflow(uaddr, size))
1074         return false;
1075 
1076     if ((access & VHOST_ACCESS_RO) &&
1077         !access_ok((void __user *)a, size))
1078         return false;
1079     if ((access & VHOST_ACCESS_WO) &&
1080         !access_ok((void __user *)a, size))
1081         return false;
1082     return true;
1083 }
1084 
1085 static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1086                    struct vhost_iotlb_msg *msg)
1087 {
1088     int ret = 0;
1089 
1090     if (asid != 0)
1091         return -EINVAL;
1092 
1093     mutex_lock(&dev->mutex);
1094     vhost_dev_lock_vqs(dev);
1095     switch (msg->type) {
1096     case VHOST_IOTLB_UPDATE:
1097         if (!dev->iotlb) {
1098             ret = -EFAULT;
1099             break;
1100         }
1101         if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
1102             ret = -EFAULT;
1103             break;
1104         }
1105         vhost_vq_meta_reset(dev);
1106         if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
1107                       msg->iova + msg->size - 1,
1108                       msg->uaddr, msg->perm)) {
1109             ret = -ENOMEM;
1110             break;
1111         }
1112         vhost_iotlb_notify_vq(dev, msg);
1113         break;
1114     case VHOST_IOTLB_INVALIDATE:
1115         if (!dev->iotlb) {
1116             ret = -EFAULT;
1117             break;
1118         }
1119         vhost_vq_meta_reset(dev);
1120         vhost_iotlb_del_range(dev->iotlb, msg->iova,
1121                       msg->iova + msg->size - 1);
1122         break;
1123     default:
1124         ret = -EINVAL;
1125         break;
1126     }
1127 
1128     vhost_dev_unlock_vqs(dev);
1129     mutex_unlock(&dev->mutex);
1130 
1131     return ret;
1132 }
1133 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1134                  struct iov_iter *from)
1135 {
1136     struct vhost_iotlb_msg msg;
1137     size_t offset;
1138     int type, ret;
1139     u32 asid = 0;
1140 
1141     ret = copy_from_iter(&type, sizeof(type), from);
1142     if (ret != sizeof(type)) {
1143         ret = -EINVAL;
1144         goto done;
1145     }
1146 
1147     switch (type) {
1148     case VHOST_IOTLB_MSG:
1149         /* There maybe a hole after type for V1 message type,
1150          * so skip it here.
1151          */
1152         offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
1153         break;
1154     case VHOST_IOTLB_MSG_V2:
1155         if (vhost_backend_has_feature(dev->vqs[0],
1156                           VHOST_BACKEND_F_IOTLB_ASID)) {
1157             ret = copy_from_iter(&asid, sizeof(asid), from);
1158             if (ret != sizeof(asid)) {
1159                 ret = -EINVAL;
1160                 goto done;
1161             }
1162             offset = 0;
1163         } else
1164             offset = sizeof(__u32);
1165         break;
1166     default:
1167         ret = -EINVAL;
1168         goto done;
1169     }
1170 
1171     iov_iter_advance(from, offset);
1172     ret = copy_from_iter(&msg, sizeof(msg), from);
1173     if (ret != sizeof(msg)) {
1174         ret = -EINVAL;
1175         goto done;
1176     }
1177 
1178     if ((msg.type == VHOST_IOTLB_UPDATE ||
1179          msg.type == VHOST_IOTLB_INVALIDATE) &&
1180          msg.size == 0) {
1181         ret = -EINVAL;
1182         goto done;
1183     }
1184 
1185     if (dev->msg_handler)
1186         ret = dev->msg_handler(dev, asid, &msg);
1187     else
1188         ret = vhost_process_iotlb_msg(dev, asid, &msg);
1189     if (ret) {
1190         ret = -EFAULT;
1191         goto done;
1192     }
1193 
1194     ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
1195           sizeof(struct vhost_msg_v2);
1196 done:
1197     return ret;
1198 }
1199 EXPORT_SYMBOL(vhost_chr_write_iter);
1200 
1201 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1202                 poll_table *wait)
1203 {
1204     __poll_t mask = 0;
1205 
1206     poll_wait(file, &dev->wait, wait);
1207 
1208     if (!list_empty(&dev->read_list))
1209         mask |= EPOLLIN | EPOLLRDNORM;
1210 
1211     return mask;
1212 }
1213 EXPORT_SYMBOL(vhost_chr_poll);
1214 
/*
 * vhost_chr_read_iter - hand the next queued vhost message to userspace.
 * @dev:     vhost device whose read_list is consumed
 * @to:      destination iterator
 * @noblock: non-zero to return -EAGAIN instead of sleeping
 *
 * Returns 0 without dequeuing anything if @to cannot hold a struct vhost_msg.
 * Otherwise waits (interruptibly, unless @noblock) for a message on
 * dev->read_list.  The wait ends with -ERESTARTSYS on a pending signal or
 * -EBADFD if the device has no IOTLB.  The dequeued message is copied using
 * the V1 or V2 layout selected by its type.  IOTLB miss messages copied in
 * full are parked on dev->pending_list, where vhost_iotlb_notify_vq() later
 * frees them.  All other messages, and short copies, are freed here.
 *
 * Returns the number of bytes copied or a negative errno.
 */
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
                int noblock)
{
    DEFINE_WAIT(wait);
    struct vhost_msg_node *node;
    ssize_t ret = 0;
    unsigned size = sizeof(struct vhost_msg);

    /*
     * NOTE(review): this up-front check uses the V1 size; a V2 message of a
     * different size is only checked via the copy_to_iter() result below.
     */
    if (iov_iter_count(to) < size)
        return 0;

    while (1) {
        /* Register on dev->wait before checking the list so a wakeup
         * arriving before schedule() is not lost.
         */
        if (!noblock)
            prepare_to_wait(&dev->wait, &wait,
                    TASK_INTERRUPTIBLE);

        node = vhost_dequeue_msg(dev, &dev->read_list);
        if (node)
            break;
        if (noblock) {
            ret = -EAGAIN;
            break;
        }
        if (signal_pending(current)) {
            ret = -ERESTARTSYS;
            break;
        }
        if (!dev->iotlb) {
            ret = -EBADFD;
            break;
        }

        schedule();
    }

    if (!noblock)
        finish_wait(&dev->wait, &wait);

    if (node) {
        struct vhost_iotlb_msg *msg;
        void *start = &node->msg;

        /* Pick the wire layout (and embedded iotlb message) by type. */
        switch (node->msg.type) {
        case VHOST_IOTLB_MSG:
            size = sizeof(node->msg);
            msg = &node->msg.iotlb;
            break;
        case VHOST_IOTLB_MSG_V2:
            size = sizeof(node->msg_v2);
            msg = &node->msg_v2.iotlb;
            break;
        default:
            BUG();
            break;
        }

        ret = copy_to_iter(start, size, to);
        /* Only fully delivered MISS requests wait for a reply. */
        if (ret != size || msg->type != VHOST_IOTLB_MISS) {
            kfree(node);
            return ret;
        }
        vhost_enqueue_msg(dev, &dev->pending_list, node);
    }

    return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1282 
1283 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1284 {
1285     struct vhost_dev *dev = vq->dev;
1286     struct vhost_msg_node *node;
1287     struct vhost_iotlb_msg *msg;
1288     bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1289 
1290     node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1291     if (!node)
1292         return -ENOMEM;
1293 
1294     if (v2) {
1295         node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1296         msg = &node->msg_v2.iotlb;
1297     } else {
1298         msg = &node->msg.iotlb;
1299     }
1300 
1301     msg->type = VHOST_IOTLB_MISS;
1302     msg->iova = iova;
1303     msg->perm = access;
1304 
1305     vhost_enqueue_msg(dev, &dev->read_list, node);
1306 
1307     return 0;
1308 }
1309 
1310 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1311              vring_desc_t __user *desc,
1312              vring_avail_t __user *avail,
1313              vring_used_t __user *used)
1314 
1315 {
1316     /* If an IOTLB device is present, the vring addresses are
1317      * GIOVAs. Access validation occurs at prefetch time. */
1318     if (vq->iotlb)
1319         return true;
1320 
1321     return access_ok(desc, vhost_get_desc_size(vq, num)) &&
1322            access_ok(avail, vhost_get_avail_size(vq, num)) &&
1323            access_ok(used, vhost_get_used_size(vq, num));
1324 }
1325 
1326 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1327                  const struct vhost_iotlb_map *map,
1328                  int type)
1329 {
1330     int access = (type == VHOST_ADDR_USED) ?
1331              VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1332 
1333     if (likely(map->perm & access))
1334         vq->meta_iotlb[type] = map;
1335 }
1336 
1337 static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1338                 int access, u64 addr, u64 len, int type)
1339 {
1340     const struct vhost_iotlb_map *map;
1341     struct vhost_iotlb *umem = vq->iotlb;
1342     u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1343 
1344     if (vhost_vq_meta_fetch(vq, addr, len, type))
1345         return true;
1346 
1347     while (len > s) {
1348         map = vhost_iotlb_itree_first(umem, addr, last);
1349         if (map == NULL || map->start > addr) {
1350             vhost_iotlb_miss(vq, addr, access);
1351             return false;
1352         } else if (!(map->perm & access)) {
1353             /* Report the possible access violation by
1354              * request another translation from userspace.
1355              */
1356             return false;
1357         }
1358 
1359         size = map->size - addr + map->start;
1360 
1361         if (orig_addr == addr && size >= len)
1362             vhost_vq_meta_update(vq, map, type);
1363 
1364         s += size;
1365         addr += size;
1366     }
1367 
1368     return true;
1369 }
1370 
1371 int vq_meta_prefetch(struct vhost_virtqueue *vq)
1372 {
1373     unsigned int num = vq->num;
1374 
1375     if (!vq->iotlb)
1376         return 1;
1377 
1378     return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1379                    vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1380            iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1381                    vhost_get_avail_size(vq, num),
1382                    VHOST_ADDR_AVAIL) &&
1383            iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1384                    vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1385 }
1386 EXPORT_SYMBOL_GPL(vq_meta_prefetch);
1387 
1388 /* Can we log writes? */
1389 /* Caller should have device mutex but not vq mutex */
1390 bool vhost_log_access_ok(struct vhost_dev *dev)
1391 {
1392     return memory_access_ok(dev, dev->umem, 1);
1393 }
1394 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1395 
1396 static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1397                   void __user *log_base,
1398                   bool log_used,
1399                   u64 log_addr)
1400 {
1401     /* If an IOTLB device is present, log_addr is a GIOVA that
1402      * will never be logged by log_used(). */
1403     if (vq->iotlb)
1404         return true;
1405 
1406     return !log_used || log_access_ok(log_base, log_addr,
1407                       vhost_get_used_size(vq, vq->num));
1408 }
1409 
1410 /* Verify access for write logging. */
1411 /* Caller should have vq mutex and device mutex */
1412 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1413                  void __user *log_base)
1414 {
1415     return vq_memory_access_ok(log_base, vq->umem,
1416                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1417         vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1418 }
1419 
1420 /* Can we start vq? */
1421 /* Caller should have vq mutex and device mutex */
1422 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1423 {
1424     if (!vq_log_access_ok(vq, vq->log_base))
1425         return false;
1426 
1427     return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1428 }
1429 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1430 
1431 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1432 {
1433     struct vhost_memory mem, *newmem;
1434     struct vhost_memory_region *region;
1435     struct vhost_iotlb *newumem, *oldumem;
1436     unsigned long size = offsetof(struct vhost_memory, regions);
1437     int i;
1438 
1439     if (copy_from_user(&mem, m, size))
1440         return -EFAULT;
1441     if (mem.padding)
1442         return -EOPNOTSUPP;
1443     if (mem.nregions > max_mem_regions)
1444         return -E2BIG;
1445     newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1446             GFP_KERNEL);
1447     if (!newmem)
1448         return -ENOMEM;
1449 
1450     memcpy(newmem, &mem, size);
1451     if (copy_from_user(newmem->regions, m->regions,
1452                flex_array_size(newmem, regions, mem.nregions))) {
1453         kvfree(newmem);
1454         return -EFAULT;
1455     }
1456 
1457     newumem = iotlb_alloc();
1458     if (!newumem) {
1459         kvfree(newmem);
1460         return -ENOMEM;
1461     }
1462 
1463     for (region = newmem->regions;
1464          region < newmem->regions + mem.nregions;
1465          region++) {
1466         if (vhost_iotlb_add_range(newumem,
1467                       region->guest_phys_addr,
1468                       region->guest_phys_addr +
1469                       region->memory_size - 1,
1470                       region->userspace_addr,
1471                       VHOST_MAP_RW))
1472             goto err;
1473     }
1474 
1475     if (!memory_access_ok(d, newumem, 0))
1476         goto err;
1477 
1478     oldumem = d->umem;
1479     d->umem = newumem;
1480 
1481     /* All memory accesses are done under some VQ mutex. */
1482     for (i = 0; i < d->nvqs; ++i) {
1483         mutex_lock(&d->vqs[i]->mutex);
1484         d->vqs[i]->umem = newumem;
1485         mutex_unlock(&d->vqs[i]->mutex);
1486     }
1487 
1488     kvfree(newmem);
1489     vhost_iotlb_free(oldumem);
1490     return 0;
1491 
1492 err:
1493     vhost_iotlb_free(newumem);
1494     kvfree(newmem);
1495     return -EFAULT;
1496 }
1497 
1498 static long vhost_vring_set_num(struct vhost_dev *d,
1499                 struct vhost_virtqueue *vq,
1500                 void __user *argp)
1501 {
1502     struct vhost_vring_state s;
1503 
1504     /* Resizing ring with an active backend?
1505      * You don't want to do that. */
1506     if (vq->private_data)
1507         return -EBUSY;
1508 
1509     if (copy_from_user(&s, argp, sizeof s))
1510         return -EFAULT;
1511 
1512     if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1513         return -EINVAL;
1514     vq->num = s.num;
1515 
1516     return 0;
1517 }
1518 
1519 static long vhost_vring_set_addr(struct vhost_dev *d,
1520                  struct vhost_virtqueue *vq,
1521                  void __user *argp)
1522 {
1523     struct vhost_vring_addr a;
1524 
1525     if (copy_from_user(&a, argp, sizeof a))
1526         return -EFAULT;
1527     if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
1528         return -EOPNOTSUPP;
1529 
1530     /* For 32bit, verify that the top 32bits of the user
1531        data are set to zero. */
1532     if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1533         (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1534         (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
1535         return -EFAULT;
1536 
1537     /* Make sure it's safe to cast pointers to vring types. */
1538     BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1539     BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1540     if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1541         (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1542         (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
1543         return -EINVAL;
1544 
1545     /* We only verify access here if backend is configured.
1546      * If it is not, we don't as size might not have been setup.
1547      * We will verify when backend is configured. */
1548     if (vq->private_data) {
1549         if (!vq_access_ok(vq, vq->num,
1550             (void __user *)(unsigned long)a.desc_user_addr,
1551             (void __user *)(unsigned long)a.avail_user_addr,
1552             (void __user *)(unsigned long)a.used_user_addr))
1553             return -EINVAL;
1554 
1555         /* Also validate log access for used ring if enabled. */
1556         if (!vq_log_used_access_ok(vq, vq->log_base,
1557                 a.flags & (0x1 << VHOST_VRING_F_LOG),
1558                 a.log_guest_addr))
1559             return -EINVAL;
1560     }
1561 
1562     vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1563     vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1564     vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1565     vq->log_addr = a.log_guest_addr;
1566     vq->used = (void __user *)(unsigned long)a.used_user_addr;
1567 
1568     return 0;
1569 }
1570 
1571 static long vhost_vring_set_num_addr(struct vhost_dev *d,
1572                      struct vhost_virtqueue *vq,
1573                      unsigned int ioctl,
1574                      void __user *argp)
1575 {
1576     long r;
1577 
1578     mutex_lock(&vq->mutex);
1579 
1580     switch (ioctl) {
1581     case VHOST_SET_VRING_NUM:
1582         r = vhost_vring_set_num(d, vq, argp);
1583         break;
1584     case VHOST_SET_VRING_ADDR:
1585         r = vhost_vring_set_addr(d, vq, argp);
1586         break;
1587     default:
1588         BUG();
1589     }
1590 
1591     mutex_unlock(&vq->mutex);
1592 
1593     return r;
1594 }
/*
 * vhost_vring_ioctl - handle the per-virtqueue VHOST_*_VRING_* ioctls.
 * @d:     vhost device owning the virtqueues
 * @ioctl: ioctl command number
 * @argp:  user pointer to the ioctl argument; its first u32 is read as the
 *         virtqueue index
 *
 * The index is rejected with -ENOBUFS if it is >= d->nvqs, and is clamped
 * with array_index_nospec() before use.  VHOST_SET_VRING_NUM/ADDR are
 * delegated to vhost_vring_set_num_addr().  All other commands run under
 * vq->mutex.
 *
 * Returns 0 on success, a negative errno on failure, or -ENOIOCTLCMD for a
 * command not handled here.
 */
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
    struct file *eventfp, *filep = NULL;
    bool pollstart = false, pollstop = false;
    struct eventfd_ctx *ctx = NULL;
    u32 __user *idxp = argp;
    struct vhost_virtqueue *vq;
    struct vhost_vring_state s;
    struct vhost_vring_file f;
    u32 idx;
    long r;

    r = get_user(idx, idxp);
    if (r < 0)
        return r;
    if (idx >= d->nvqs)
        return -ENOBUFS;

    /* Clamp idx under speculation before indexing d->vqs. */
    idx = array_index_nospec(idx, d->nvqs);
    vq = d->vqs[idx];

    if (ioctl == VHOST_SET_VRING_NUM ||
        ioctl == VHOST_SET_VRING_ADDR) {
        return vhost_vring_set_num_addr(d, vq, ioctl, argp);
    }

    mutex_lock(&vq->mutex);

    switch (ioctl) {
    case VHOST_SET_VRING_BASE:
        /* Moving base with an active backend?
         * You don't want to do that. */
        if (vq->private_data) {
            r = -EBUSY;
            break;
        }
        if (copy_from_user(&s, argp, sizeof s)) {
            r = -EFAULT;
            break;
        }
        if (s.num > 0xffff) {
            r = -EINVAL;
            break;
        }
        vq->last_avail_idx = s.num;
        /* Forget the cached index value. */
        vq->avail_idx = vq->last_avail_idx;
        break;
    case VHOST_GET_VRING_BASE:
        s.index = idx;
        s.num = vq->last_avail_idx;
        if (copy_to_user(argp, &s, sizeof s))
            r = -EFAULT;
        break;
    case VHOST_SET_VRING_KICK:
        if (copy_from_user(&f, argp, sizeof f)) {
            r = -EFAULT;
            break;
        }
        eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
        if (IS_ERR(eventfp)) {
            r = PTR_ERR(eventfp);
            break;
        }
        /*
         * On a change, the old kick file is queued for fput() and polling
         * is stopped on it and started on the new one.  If the file is
         * unchanged, the reference just taken is dropped instead.
         */
        if (eventfp != vq->kick) {
            pollstop = (filep = vq->kick) != NULL;
            pollstart = (vq->kick = eventfp) != NULL;
        } else
            filep = eventfp;
        break;
    case VHOST_SET_VRING_CALL:
        if (copy_from_user(&f, argp, sizeof f)) {
            r = -EFAULT;
            break;
        }
        ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
        if (IS_ERR(ctx)) {
            r = PTR_ERR(ctx);
            break;
        }

        /* After the swap, ctx holds the old context; it is put below. */
        swap(ctx, vq->call_ctx.ctx);
        break;
    case VHOST_SET_VRING_ERR:
        if (copy_from_user(&f, argp, sizeof f)) {
            r = -EFAULT;
            break;
        }
        ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
        if (IS_ERR(ctx)) {
            r = PTR_ERR(ctx);
            break;
        }
        /* After the swap, ctx holds the old context; it is put below. */
        swap(ctx, vq->error_ctx);
        break;
    case VHOST_SET_VRING_ENDIAN:
        r = vhost_set_vring_endian(vq, argp);
        break;
    case VHOST_GET_VRING_ENDIAN:
        r = vhost_get_vring_endian(vq, idx, argp);
        break;
    case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
        if (copy_from_user(&s, argp, sizeof(s))) {
            r = -EFAULT;
            break;
        }
        vq->busyloop_timeout = s.num;
        break;
    case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
        s.index = idx;
        s.num = vq->busyloop_timeout;
        if (copy_to_user(argp, &s, sizeof(s)))
            r = -EFAULT;
        break;
    default:
        r = -ENOIOCTLCMD;
    }

    if (pollstop && vq->handle_kick)
        vhost_poll_stop(&vq->poll);

    /* Drop the replaced eventfd context / kick file, if any. */
    if (!IS_ERR_OR_NULL(ctx))
        eventfd_ctx_put(ctx);
    if (filep)
        fput(filep);

    /* Starting the poll on the new kick file decides the return value. */
    if (pollstart && vq->handle_kick)
        r = vhost_poll_start(&vq->poll, vq->kick);

    mutex_unlock(&vq->mutex);

    /* Flush after dropping vq->mutex once polling on the old kick stopped. */
    if (pollstop && vq->handle_kick)
        vhost_dev_flush(vq->poll.dev);
    return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
1731 
1732 int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1733 {
1734     struct vhost_iotlb *niotlb, *oiotlb;
1735     int i;
1736 
1737     niotlb = iotlb_alloc();
1738     if (!niotlb)
1739         return -ENOMEM;
1740 
1741     oiotlb = d->iotlb;
1742     d->iotlb = niotlb;
1743 
1744     for (i = 0; i < d->nvqs; ++i) {
1745         struct vhost_virtqueue *vq = d->vqs[i];
1746 
1747         mutex_lock(&vq->mutex);
1748         vq->iotlb = niotlb;
1749         __vhost_vq_meta_reset(vq);
1750         mutex_unlock(&vq->mutex);
1751     }
1752 
1753     vhost_iotlb_free(oiotlb);
1754 
1755     return 0;
1756 }
1757 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1758 
/*
 * vhost_dev_ioctl - handle device-wide vhost ioctls.
 * @d:     vhost device
 * @ioctl: ioctl command number
 * @argp:  user pointer to the ioctl argument
 *
 * VHOST_SET_OWNER is accepted from any caller.  Every other command first
 * requires vhost_dev_check_owner() to succeed.  Handles VHOST_SET_MEM_TABLE,
 * VHOST_SET_LOG_BASE and VHOST_SET_LOG_FD; anything else returns
 * -ENOIOCTLCMD (NOTE(review): presumably so the caller can try its other
 * handlers).
 *
 * Caller must have device mutex.
 */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
    struct eventfd_ctx *ctx;
    u64 p;
    long r;
    int i, fd;

    /* If you are not the owner, you can become one */
    if (ioctl == VHOST_SET_OWNER) {
        r = vhost_dev_set_owner(d);
        goto done;
    }

    /* You must be the owner to do anything else */
    r = vhost_dev_check_owner(d);
    if (r)
        goto done;

    switch (ioctl) {
    case VHOST_SET_MEM_TABLE:
        r = vhost_set_memory(d, argp);
        break;
    case VHOST_SET_LOG_BASE:
        if (copy_from_user(&p, argp, sizeof p)) {
            r = -EFAULT;
            break;
        }
        /* Reject a base that does not fit in an unsigned long. */
        if ((u64)(unsigned long)p != p) {
            r = -EFAULT;
            break;
        }
        /*
         * A vq that fails the check sets r = -EFAULT and keeps its old
         * log_base, but the loop still goes on to update the other vqs.
         */
        for (i = 0; i < d->nvqs; ++i) {
            struct vhost_virtqueue *vq;
            void __user *base = (void __user *)(unsigned long)p;
            vq = d->vqs[i];
            mutex_lock(&vq->mutex);
            /* If ring is inactive, will check when it's enabled. */
            if (vq->private_data && !vq_log_access_ok(vq, base))
                r = -EFAULT;
            else
                vq->log_base = base;
            mutex_unlock(&vq->mutex);
        }
        break;
    case VHOST_SET_LOG_FD:
        r = get_user(fd, (int __user *)argp);
        if (r < 0)
            break;
        ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
        if (IS_ERR(ctx)) {
            r = PTR_ERR(ctx);
            break;
        }
        /* Install the new context on the device and every vq; the old
         * one (left in ctx by the swap) is released afterwards.
         */
        swap(ctx, d->log_ctx);
        for (i = 0; i < d->nvqs; ++i) {
            mutex_lock(&d->vqs[i]->mutex);
            d->vqs[i]->log_ctx = d->log_ctx;
            mutex_unlock(&d->vqs[i]->mutex);
        }
        if (ctx)
            eventfd_ctx_put(ctx);
        break;
    default:
        r = -ENOIOCTLCMD;
        break;
    }
done:
    return r;
}
EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1830 
1831 /* TODO: This is really inefficient.  We need something like get_user()
1832  * (instruction directly accesses the data, with an exception table entry
1833  * returning -EFAULT). See Documentation/x86/exception-tables.rst.
1834  */
1835 static int set_bit_to_user(int nr, void __user *addr)
1836 {
1837     unsigned long log = (unsigned long)addr;
1838     struct page *page;
1839     void *base;
1840     int bit = nr + (log % PAGE_SIZE) * 8;
1841     int r;
1842 
1843     r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
1844     if (r < 0)
1845         return r;
1846     BUG_ON(r != 1);
1847     base = kmap_atomic(page);
1848     set_bit(bit, base);
1849     kunmap_atomic(base);
1850     unpin_user_pages_dirty_lock(&page, 1, true);
1851     return 0;
1852 }
1853 
1854 static int log_write(void __user *log_base,
1855              u64 write_address, u64 write_length)
1856 {
1857     u64 write_page = write_address / VHOST_PAGE_SIZE;
1858     int r;
1859 
1860     if (!write_length)
1861         return 0;
1862     write_length += write_address % VHOST_PAGE_SIZE;
1863     for (;;) {
1864         u64 base = (u64)(unsigned long)log_base;
1865         u64 log = base + write_page / 8;
1866         int bit = write_page % 8;
1867         if ((u64)(unsigned long)log != log)
1868             return -EFAULT;
1869         r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
1870         if (r < 0)
1871             return r;
1872         if (write_length <= VHOST_PAGE_SIZE)
1873             break;
1874         write_length -= VHOST_PAGE_SIZE;
1875         write_page += 1;
1876     }
1877     return r;
1878 }
1879 
1880 static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
1881 {
1882     struct vhost_iotlb *umem = vq->umem;
1883     struct vhost_iotlb_map *u;
1884     u64 start, end, l, min;
1885     int r;
1886     bool hit = false;
1887 
1888     while (len) {
1889         min = len;
1890         /* More than one GPAs can be mapped into a single HVA. So
1891          * iterate all possible umems here to be safe.
1892          */
1893         list_for_each_entry(u, &umem->list, link) {
1894             if (u->addr > hva - 1 + len ||
1895                 u->addr - 1 + u->size < hva)
1896                 continue;
1897             start = max(u->addr, hva);
1898             end = min(u->addr - 1 + u->size, hva - 1 + len);
1899             l = end - start + 1;
1900             r = log_write(vq->log_base,
1901                       u->start + start - u->addr,
1902                       l);
1903             if (r < 0)
1904                 return r;
1905             hit = true;
1906             min = min(l, min);
1907         }
1908 
1909         if (!hit)
1910             return -EFAULT;
1911 
1912         len -= min;
1913         hva += min;
1914     }
1915 
1916     return 0;
1917 }
1918 
1919 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
1920 {
1921     struct iovec *iov = vq->log_iov;
1922     int i, ret;
1923 
1924     if (!vq->iotlb)
1925         return log_write(vq->log_base, vq->log_addr + used_offset, len);
1926 
1927     ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
1928                  len, iov, 64, VHOST_ACCESS_WO);
1929     if (ret < 0)
1930         return ret;
1931 
1932     for (i = 0; i < ret; i++) {
1933         ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1934                     iov[i].iov_len);
1935         if (ret)
1936             return ret;
1937     }
1938 
1939     return 0;
1940 }
1941 
1942 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1943             unsigned int log_num, u64 len, struct iovec *iov, int count)
1944 {
1945     int i, r;
1946 
1947     /* Make sure data written is seen before log. */
1948     smp_wmb();
1949 
1950     if (vq->iotlb) {
1951         for (i = 0; i < count; i++) {
1952             r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1953                       iov[i].iov_len);
1954             if (r < 0)
1955                 return r;
1956         }
1957         return 0;
1958     }
1959 
1960     for (i = 0; i < log_num; ++i) {
1961         u64 l = min(log[i].len, len);
1962         r = log_write(vq->log_base, log[i].addr, l);
1963         if (r < 0)
1964             return r;
1965         len -= l;
1966         if (!len) {
1967             if (vq->log_ctx)
1968                 eventfd_signal(vq->log_ctx, 1);
1969             return 0;
1970         }
1971     }
1972     /* Length written exceeds what we have stored. This is a bug. */
1973     BUG();
1974     return 0;
1975 }
1976 EXPORT_SYMBOL_GPL(vhost_log_write);
1977 
1978 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1979 {
1980     void __user *used;
1981     if (vhost_put_used_flags(vq))
1982         return -EFAULT;
1983     if (unlikely(vq->log_used)) {
1984         /* Make sure the flag is seen before log. */
1985         smp_wmb();
1986         /* Log used flag write. */
1987         used = &vq->used->flags;
1988         log_used(vq, (used - (void __user *)vq->used),
1989              sizeof vq->used->flags);
1990         if (vq->log_ctx)
1991             eventfd_signal(vq->log_ctx, 1);
1992     }
1993     return 0;
1994 }
1995 
1996 static int vhost_update_avail_event(struct vhost_virtqueue *vq)
1997 {
1998     if (vhost_put_avail_event(vq))
1999         return -EFAULT;
2000     if (unlikely(vq->log_used)) {
2001         void __user *used;
2002         /* Make sure the event is seen before log. */
2003         smp_wmb();
2004         /* Log avail event write */
2005         used = vhost_avail_event(vq);
2006         log_used(vq, (used - (void __user *)vq->used),
2007              sizeof *vhost_avail_event(vq));
2008         if (vq->log_ctx)
2009             eventfd_signal(vq->log_ctx, 1);
2010     }
2011     return 0;
2012 }
2013 
2014 int vhost_vq_init_access(struct vhost_virtqueue *vq)
2015 {
2016     __virtio16 last_used_idx;
2017     int r;
2018     bool is_le = vq->is_le;
2019 
2020     if (!vq->private_data)
2021         return 0;
2022 
2023     vhost_init_is_le(vq);
2024 
2025     r = vhost_update_used_flags(vq);
2026     if (r)
2027         goto err;
2028     vq->signalled_used_valid = false;
2029     if (!vq->iotlb &&
2030         !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
2031         r = -EFAULT;
2032         goto err;
2033     }
2034     r = vhost_get_used_idx(vq, &last_used_idx);
2035     if (r) {
2036         vq_err(vq, "Can't access used idx at %p\n",
2037                &vq->used->idx);
2038         goto err;
2039     }
2040     vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
2041     return 0;
2042 
2043 err:
2044     vq->is_le = is_le;
2045     return r;
2046 }
2047 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
2048 
2049 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
2050               struct iovec iov[], int iov_size, int access)
2051 {
2052     const struct vhost_iotlb_map *map;
2053     struct vhost_dev *dev = vq->dev;
2054     struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
2055     struct iovec *_iov;
2056     u64 s = 0;
2057     int ret = 0;
2058 
2059     while ((u64)len > s) {
2060         u64 size;
2061         if (unlikely(ret >= iov_size)) {
2062             ret = -ENOBUFS;
2063             break;
2064         }
2065 
2066         map = vhost_iotlb_itree_first(umem, addr, addr + len - 1);
2067         if (map == NULL || map->start > addr) {
2068             if (umem != dev->iotlb) {
2069                 ret = -EFAULT;
2070                 break;
2071             }
2072             ret = -EAGAIN;
2073             break;
2074         } else if (!(map->perm & access)) {
2075             ret = -EPERM;
2076             break;
2077         }
2078 
2079         _iov = iov + ret;
2080         size = map->size - addr + map->start;
2081         _iov->iov_len = min((u64)len - s, size);
2082         _iov->iov_base = (void __user *)(unsigned long)
2083                  (map->addr + addr - map->start);
2084         s += size;
2085         addr += size;
2086         ++ret;
2087     }
2088 
2089     if (ret == -EAGAIN)
2090         vhost_iotlb_miss(vq, addr, access);
2091     return ret;
2092 }
2093 
2094 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
2095  * function returns the next descriptor in the chain,
2096  * or -1U if we're at the end. */
2097 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
2098 {
2099     unsigned int next;
2100 
2101     /* If this descriptor says it doesn't chain, we're done. */
2102     if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
2103         return -1U;
2104 
2105     /* Check they're not leading us off end of descriptors. */
2106     next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
2107     return next;
2108 }
2109 
2110 static int get_indirect(struct vhost_virtqueue *vq,
2111             struct iovec iov[], unsigned int iov_size,
2112             unsigned int *out_num, unsigned int *in_num,
2113             struct vhost_log *log, unsigned int *log_num,
2114             struct vring_desc *indirect)
2115 {
2116     struct vring_desc desc;
2117     unsigned int i = 0, count, found = 0;
2118     u32 len = vhost32_to_cpu(vq, indirect->len);
2119     struct iov_iter from;
2120     int ret, access;
2121 
2122     /* Sanity check */
2123     if (unlikely(len % sizeof desc)) {
2124         vq_err(vq, "Invalid length in indirect descriptor: "
2125                "len 0x%llx not multiple of 0x%zx\n",
2126                (unsigned long long)len,
2127                sizeof desc);
2128         return -EINVAL;
2129     }
2130 
2131     ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
2132                  UIO_MAXIOV, VHOST_ACCESS_RO);
2133     if (unlikely(ret < 0)) {
2134         if (ret != -EAGAIN)
2135             vq_err(vq, "Translation failure %d in indirect.\n", ret);
2136         return ret;
2137     }
2138     iov_iter_init(&from, READ, vq->indirect, ret, len);
2139     count = len / sizeof desc;
2140     /* Buffers are chained via a 16 bit next field, so
2141      * we can have at most 2^16 of these. */
2142     if (unlikely(count > USHRT_MAX + 1)) {
2143         vq_err(vq, "Indirect buffer length too big: %d\n",
2144                indirect->len);
2145         return -E2BIG;
2146     }
2147 
2148     do {
2149         unsigned iov_count = *in_num + *out_num;
2150         if (unlikely(++found > count)) {
2151             vq_err(vq, "Loop detected: last one at %u "
2152                    "indirect size %u\n",
2153                    i, count);
2154             return -EINVAL;
2155         }
2156         if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
2157             vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
2158                    i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2159             return -EINVAL;
2160         }
2161         if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
2162             vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
2163                    i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2164             return -EINVAL;
2165         }
2166 
2167         if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2168             access = VHOST_ACCESS_WO;
2169         else
2170             access = VHOST_ACCESS_RO;
2171 
2172         ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2173                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2174                      iov_size - iov_count, access);
2175         if (unlikely(ret < 0)) {
2176             if (ret != -EAGAIN)
2177                 vq_err(vq, "Translation failure %d indirect idx %d\n",
2178                     ret, i);
2179             return ret;
2180         }
2181         /* If this is an input descriptor, increment that count. */
2182         if (access == VHOST_ACCESS_WO) {
2183             *in_num += ret;
2184             if (unlikely(log && ret)) {
2185                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2186                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2187                 ++*log_num;
2188             }
2189         } else {
2190             /* If it's an output descriptor, they're all supposed
2191              * to come before any input descriptors. */
2192             if (unlikely(*in_num)) {
2193                 vq_err(vq, "Indirect descriptor "
2194                        "has out after in: idx %d\n", i);
2195                 return -EINVAL;
2196             }
2197             *out_num += ret;
2198         }
2199     } while ((i = next_desc(vq, &desc)) != -1);
2200     return 0;
2201 }
2202 
2203 /* This looks in the virtqueue and for the first available buffer, and converts
2204  * it to an iovec for convenient access.  Since descriptors consist of some
2205  * number of output then some number of input descriptors, it's actually two
2206  * iovecs, but we pack them into one and note how many of each there were.
2207  *
2208  * This function returns the descriptor number found, or vq->num (which is
2209  * never a valid descriptor number) if none was found.  A negative code is
2210  * returned on error. */
2211 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2212               struct iovec iov[], unsigned int iov_size,
2213               unsigned int *out_num, unsigned int *in_num,
2214               struct vhost_log *log, unsigned int *log_num)
2215 {
2216     struct vring_desc desc;
2217     unsigned int i, head, found = 0;
2218     u16 last_avail_idx;
2219     __virtio16 avail_idx;
2220     __virtio16 ring_head;
2221     int ret, access;
2222 
2223     /* Check it isn't doing very strange things with descriptor numbers. */
2224     last_avail_idx = vq->last_avail_idx;
2225 
2226     if (vq->avail_idx == vq->last_avail_idx) {
2227         if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
2228             vq_err(vq, "Failed to access avail idx at %p\n",
2229                 &vq->avail->idx);
2230             return -EFAULT;
2231         }
2232         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2233 
2234         if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2235             vq_err(vq, "Guest moved used index from %u to %u",
2236                 last_avail_idx, vq->avail_idx);
2237             return -EFAULT;
2238         }
2239 
2240         /* If there's nothing new since last we looked, return
2241          * invalid.
2242          */
2243         if (vq->avail_idx == last_avail_idx)
2244             return vq->num;
2245 
2246         /* Only get avail ring entries after they have been
2247          * exposed by guest.
2248          */
2249         smp_rmb();
2250     }
2251 
2252     /* Grab the next descriptor number they're advertising, and increment
2253      * the index we've seen. */
2254     if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
2255         vq_err(vq, "Failed to read head: idx %d address %p\n",
2256                last_avail_idx,
2257                &vq->avail->ring[last_avail_idx % vq->num]);
2258         return -EFAULT;
2259     }
2260 
2261     head = vhost16_to_cpu(vq, ring_head);
2262 
2263     /* If their number is silly, that's an error. */
2264     if (unlikely(head >= vq->num)) {
2265         vq_err(vq, "Guest says index %u > %u is available",
2266                head, vq->num);
2267         return -EINVAL;
2268     }
2269 
2270     /* When we start there are none of either input nor output. */
2271     *out_num = *in_num = 0;
2272     if (unlikely(log))
2273         *log_num = 0;
2274 
2275     i = head;
2276     do {
2277         unsigned iov_count = *in_num + *out_num;
2278         if (unlikely(i >= vq->num)) {
2279             vq_err(vq, "Desc index is %u > %u, head = %u",
2280                    i, vq->num, head);
2281             return -EINVAL;
2282         }
2283         if (unlikely(++found > vq->num)) {
2284             vq_err(vq, "Loop detected: last one at %u "
2285                    "vq size %u head %u\n",
2286                    i, vq->num, head);
2287             return -EINVAL;
2288         }
2289         ret = vhost_get_desc(vq, &desc, i);
2290         if (unlikely(ret)) {
2291             vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2292                    i, vq->desc + i);
2293             return -EFAULT;
2294         }
2295         if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2296             ret = get_indirect(vq, iov, iov_size,
2297                        out_num, in_num,
2298                        log, log_num, &desc);
2299             if (unlikely(ret < 0)) {
2300                 if (ret != -EAGAIN)
2301                     vq_err(vq, "Failure detected "
2302                         "in indirect descriptor at idx %d\n", i);
2303                 return ret;
2304             }
2305             continue;
2306         }
2307 
2308         if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2309             access = VHOST_ACCESS_WO;
2310         else
2311             access = VHOST_ACCESS_RO;
2312         ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2313                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2314                      iov_size - iov_count, access);
2315         if (unlikely(ret < 0)) {
2316             if (ret != -EAGAIN)
2317                 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2318                     ret, i);
2319             return ret;
2320         }
2321         if (access == VHOST_ACCESS_WO) {
2322             /* If this is an input descriptor,
2323              * increment that count. */
2324             *in_num += ret;
2325             if (unlikely(log && ret)) {
2326                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2327                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2328                 ++*log_num;
2329             }
2330         } else {
2331             /* If it's an output descriptor, they're all supposed
2332              * to come before any input descriptors. */
2333             if (unlikely(*in_num)) {
2334                 vq_err(vq, "Descriptor has out after in: "
2335                        "idx %d\n", i);
2336                 return -EINVAL;
2337             }
2338             *out_num += ret;
2339         }
2340     } while ((i = next_desc(vq, &desc)) != -1);
2341 
2342     /* On success, increment avail index. */
2343     vq->last_avail_idx++;
2344 
2345     /* Assume notifications from guest are disabled at this point,
2346      * if they aren't we would need to update avail_event index. */
2347     BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2348     return head;
2349 }
2350 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2351 
2352 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2353 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2354 {
2355     vq->last_avail_idx -= n;
2356 }
2357 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2358 
2359 /* After we've used one of their buffers, we tell them about it.  We'll then
2360  * want to notify the guest, using eventfd. */
2361 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2362 {
2363     struct vring_used_elem heads = {
2364         cpu_to_vhost32(vq, head),
2365         cpu_to_vhost32(vq, len)
2366     };
2367 
2368     return vhost_add_used_n(vq, &heads, 1);
2369 }
2370 EXPORT_SYMBOL_GPL(vhost_add_used);
2371 
2372 static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2373                 struct vring_used_elem *heads,
2374                 unsigned count)
2375 {
2376     vring_used_elem_t __user *used;
2377     u16 old, new;
2378     int start;
2379 
2380     start = vq->last_used_idx & (vq->num - 1);
2381     used = vq->used->ring + start;
2382     if (vhost_put_used(vq, heads, start, count)) {
2383         vq_err(vq, "Failed to write used");
2384         return -EFAULT;
2385     }
2386     if (unlikely(vq->log_used)) {
2387         /* Make sure data is seen before log. */
2388         smp_wmb();
2389         /* Log used ring entry write. */
2390         log_used(vq, ((void __user *)used - (void __user *)vq->used),
2391              count * sizeof *used);
2392     }
2393     old = vq->last_used_idx;
2394     new = (vq->last_used_idx += count);
2395     /* If the driver never bothers to signal in a very long while,
2396      * used index might wrap around. If that happens, invalidate
2397      * signalled_used index we stored. TODO: make sure driver
2398      * signals at least once in 2^16 and remove this. */
2399     if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2400         vq->signalled_used_valid = false;
2401     return 0;
2402 }
2403 
/* After we've used one of their buffers, we tell them about it.  We'll then
 * want to notify the guest, using eventfd.
 *
 * Append @count used elements from @heads to the used ring, then publish
 * the new used index.  Ring slots are computed as last_used_idx & (num - 1).
 * NOTE(review): this presumably relies on vq->num being a power of two;
 * confirm where that is enforced.  A batch that would run past the end of
 * the ring is written in two pieces.  This function does not kick
 * vq->call_ctx; see vhost_signal() for that.
 *
 * Returns 0 on success or -EFAULT.  An error from the second piece is
 * returned only after the used index has been published.
 * __vhost_add_used_n() advances last_used_idx only on success, so the
 * published index should then cover just the first piece.  This assumes
 * vhost_put_used_idx() writes vq->last_used_idx; verify that.
 */
int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
             unsigned count)
{
    int start, n, r;

    start = vq->last_used_idx & (vq->num - 1);
    /* Free slots between start and the physical end of the ring. */
    n = vq->num - start;
    if (n < count) {
        /* Batch wraps: fill up to the end of the ring first. */
        r = __vhost_add_used_n(vq, heads, n);
        if (r < 0)
            return r;
        heads += n;
        count -= n;
    }
    r = __vhost_add_used_n(vq, heads, count);

    /* Make sure buffer is written before we update index. */
    smp_wmb();
    if (vhost_put_used_idx(vq)) {
        vq_err(vq, "Failed to increment used idx");
        return -EFAULT;
    }
    if (unlikely(vq->log_used)) {
        /* Make sure used idx is seen before log. */
        smp_wmb();
        /* Log used index update. */
        log_used(vq, offsetof(struct vring_used, idx),
             sizeof vq->used->idx);
        /* Kick the dirty-log eventfd, if one is configured. */
        if (vq->log_ctx)
            eventfd_signal(vq->log_ctx, 1);
    }
    return r;
}
EXPORT_SYMBOL_GPL(vhost_add_used_n);
2440 
2441 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2442 {
2443     __u16 old, new;
2444     __virtio16 event;
2445     bool v;
2446     /* Flush out used index updates. This is paired
2447      * with the barrier that the Guest executes when enabling
2448      * interrupts. */
2449     smp_mb();
2450 
2451     if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2452         unlikely(vq->avail_idx == vq->last_avail_idx))
2453         return true;
2454 
2455     if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2456         __virtio16 flags;
2457         if (vhost_get_avail_flags(vq, &flags)) {
2458             vq_err(vq, "Failed to get flags");
2459             return true;
2460         }
2461         return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2462     }
2463     old = vq->signalled_used;
2464     v = vq->signalled_used_valid;
2465     new = vq->signalled_used = vq->last_used_idx;
2466     vq->signalled_used_valid = true;
2467 
2468     if (unlikely(!v))
2469         return true;
2470 
2471     if (vhost_get_used_event(vq, &event)) {
2472         vq_err(vq, "Failed to get used event idx");
2473         return true;
2474     }
2475     return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2476 }
2477 
2478 /* This actually signals the guest, using eventfd. */
2479 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2480 {
2481     /* Signal the Guest tell them we used something up. */
2482     if (vq->call_ctx.ctx && vhost_notify(dev, vq))
2483         eventfd_signal(vq->call_ctx.ctx, 1);
2484 }
2485 EXPORT_SYMBOL_GPL(vhost_signal);
2486 
/* Add one used element (@head, @len), then signal the guest if it wants
 * to be notified.  The return value of vhost_add_used() is not checked;
 * the signal attempt happens either way.
 */
void vhost_add_used_and_signal(struct vhost_dev *dev,
                   struct vhost_virtqueue *vq,
                   unsigned int head, int len)
{
    vhost_add_used(vq, head, len);
    vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2496 
/* Batched form of vhost_add_used_and_signal().  Add @count used elements
 * from @heads, then signal the guest if it wants to be notified.  As in
 * the single-element form, a failure from vhost_add_used_n() does not stop
 * the signal attempt.
 */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
                 struct vhost_virtqueue *vq,
                 struct vring_used_elem *heads, unsigned count)
{
    vhost_add_used_n(vq, heads, count);
    vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2506 
2507 /* return true if we're sure that avaiable ring is empty */
2508 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2509 {
2510     __virtio16 avail_idx;
2511     int r;
2512 
2513     if (vq->avail_idx != vq->last_avail_idx)
2514         return false;
2515 
2516     r = vhost_get_avail_idx(vq, &avail_idx);
2517     if (unlikely(r))
2518         return false;
2519     vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2520 
2521     return vq->avail_idx == vq->last_avail_idx;
2522 }
2523 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2524 
2525 /* OK, now we need to know about added descriptors. */
2526 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2527 {
2528     __virtio16 avail_idx;
2529     int r;
2530 
2531     if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2532         return false;
2533     vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2534     if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2535         r = vhost_update_used_flags(vq);
2536         if (r) {
2537             vq_err(vq, "Failed to enable notification at %p: %d\n",
2538                    &vq->used->flags, r);
2539             return false;
2540         }
2541     } else {
2542         r = vhost_update_avail_event(vq);
2543         if (r) {
2544             vq_err(vq, "Failed to update avail event index at %p: %d\n",
2545                    vhost_avail_event(vq), r);
2546             return false;
2547         }
2548     }
2549     /* They could have slipped one in as we were doing that: make
2550      * sure it's written, then check again. */
2551     smp_mb();
2552     r = vhost_get_avail_idx(vq, &avail_idx);
2553     if (r) {
2554         vq_err(vq, "Failed to check avail idx at %p: %d\n",
2555                &vq->avail->idx, r);
2556         return false;
2557     }
2558     vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2559 
2560     return vq->avail_idx != vq->last_avail_idx;
2561 }
2562 EXPORT_SYMBOL_GPL(vhost_enable_notify);
2563 
2564 /* We don't need to be notified again. */
2565 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2566 {
2567     int r;
2568 
2569     if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2570         return;
2571     vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2572     if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2573         r = vhost_update_used_flags(vq);
2574         if (r)
2575             vq_err(vq, "Failed to disable notification at %p: %d\n",
2576                    &vq->used->flags, r);
2577     }
2578 }
2579 EXPORT_SYMBOL_GPL(vhost_disable_notify);
2580 
2581 /* Create a new message. */
2582 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2583 {
2584     struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
2585     if (!node)
2586         return NULL;
2587 
2588     /* Make sure all padding within the structure is initialized. */
2589     memset(&node->msg, 0, sizeof node->msg);
2590     node->vq = vq;
2591     node->msg.type = type;
2592     return node;
2593 }
2594 EXPORT_SYMBOL_GPL(vhost_new_msg);
2595 
2596 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2597                struct vhost_msg_node *node)
2598 {
2599     spin_lock(&dev->iotlb_lock);
2600     list_add_tail(&node->node, head);
2601     spin_unlock(&dev->iotlb_lock);
2602 
2603     wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2604 }
2605 EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2606 
2607 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2608                      struct list_head *head)
2609 {
2610     struct vhost_msg_node *node = NULL;
2611 
2612     spin_lock(&dev->iotlb_lock);
2613     if (!list_empty(head)) {
2614         node = list_first_entry(head, struct vhost_msg_node,
2615                     node);
2616         list_del(&node->node);
2617     }
2618     spin_unlock(&dev->iotlb_lock);
2619 
2620     return node;
2621 }
2622 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2623 
2624 void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
2625 {
2626     struct vhost_virtqueue *vq;
2627     int i;
2628 
2629     mutex_lock(&dev->mutex);
2630     for (i = 0; i < dev->nvqs; ++i) {
2631         vq = dev->vqs[i];
2632         mutex_lock(&vq->mutex);
2633         vq->acked_backend_features = features;
2634         mutex_unlock(&vq->mutex);
2635     }
2636     mutex_unlock(&dev->mutex);
2637 }
2638 EXPORT_SYMBOL_GPL(vhost_set_backend_features);
2639 
2640 static int __init vhost_init(void)
2641 {
2642     return 0;
2643 }
2644 
2645 static void __exit vhost_exit(void)
2646 {
2647 }
2648 
2649 module_init(vhost_init);
2650 module_exit(vhost_exit);
2651 
2652 MODULE_VERSION("0.0.1");
2653 MODULE_LICENSE("GPL v2");
2654 MODULE_AUTHOR("Michael S. Tsirkin");
2655 MODULE_DESCRIPTION("Host kernel accelerator for virtio");