Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 #define _GNU_SOURCE
0003 #include <getopt.h>
0004 #include <limits.h>
0005 #include <string.h>
0006 #include <poll.h>
0007 #include <sys/eventfd.h>
0008 #include <stdlib.h>
0009 #include <assert.h>
0010 #include <unistd.h>
0011 #include <sys/ioctl.h>
0012 #include <sys/stat.h>
0013 #include <sys/types.h>
0014 #include <fcntl.h>
0015 #include <stdbool.h>
0016 #include <linux/virtio_types.h>
0017 #include <linux/vhost.h>
0018 #include <linux/virtio.h>
0019 #include <linux/virtio_ring.h>
0020 #include "../../drivers/vhost/test.h"
0021 
/*
 * Sentinel batch size: when run_test() is given this value it draws a new
 * random batch size on every pass instead of using a fixed one.
 * Parenthesized so the negative constant expands safely in any expression.
 */
#define RANDOM_BATCH (-1)
0023 
/*
 * Unused by this test.
 * NOTE(review): presumably defined only to satisfy declarations in the
 * userspace kernel-shim headers -- confirm before removing.
 */
void *__kmalloc_fake;
void *__kfree_ignore_start;
void *__kfree_ignore_end;
0026 
0027 struct vq_info {
0028     int kick;
0029     int call;
0030     int num;
0031     int idx;
0032     void *ring;
0033     /* copy used for control */
0034     struct vring vring;
0035     struct virtqueue *vq;
0036 };
0037 
0038 struct vdev_info {
0039     struct virtio_device vdev;
0040     int control;
0041     struct pollfd fds[1];
0042     struct vq_info vqs[1];
0043     int nvqs;
0044     void *buf;
0045     size_t buf_size;
0046     struct vhost_memory *mem;
0047 };
0048 
0049 static const struct vhost_vring_file no_backend = { .fd = -1 },
0050                      backend = { .fd = 1 };
0051 static const struct vhost_vring_state null_state = {};
0052 
0053 bool vq_notify(struct virtqueue *vq)
0054 {
0055     struct vq_info *info = vq->priv;
0056     unsigned long long v = 1;
0057     int r;
0058     r = write(info->kick, &v, sizeof v);
0059     assert(r == sizeof v);
0060     return true;
0061 }
0062 
/*
 * Callback hook passed to vring_new_virtqueue(). It does nothing:
 * run_test() collects completions itself with virtqueue_get_buf().
 */
void vq_callback(struct virtqueue *vq)
{
    (void)vq;
}
0066 
0067 
0068 void vhost_vq_setup(struct vdev_info *dev, struct vq_info *info)
0069 {
0070     struct vhost_vring_state state = { .index = info->idx };
0071     struct vhost_vring_file file = { .index = info->idx };
0072     unsigned long long features = dev->vdev.features;
0073     struct vhost_vring_addr addr = {
0074         .index = info->idx,
0075         .desc_user_addr = (uint64_t)(unsigned long)info->vring.desc,
0076         .avail_user_addr = (uint64_t)(unsigned long)info->vring.avail,
0077         .used_user_addr = (uint64_t)(unsigned long)info->vring.used,
0078     };
0079     int r;
0080     r = ioctl(dev->control, VHOST_SET_FEATURES, &features);
0081     assert(r >= 0);
0082     state.num = info->vring.num;
0083     r = ioctl(dev->control, VHOST_SET_VRING_NUM, &state);
0084     assert(r >= 0);
0085     state.num = 0;
0086     r = ioctl(dev->control, VHOST_SET_VRING_BASE, &state);
0087     assert(r >= 0);
0088     r = ioctl(dev->control, VHOST_SET_VRING_ADDR, &addr);
0089     assert(r >= 0);
0090     file.fd = info->kick;
0091     r = ioctl(dev->control, VHOST_SET_VRING_KICK, &file);
0092     assert(r >= 0);
0093     file.fd = info->call;
0094     r = ioctl(dev->control, VHOST_SET_VRING_CALL, &file);
0095     assert(r >= 0);
0096 }
0097 
0098 static void vq_reset(struct vq_info *info, int num, struct virtio_device *vdev)
0099 {
0100     if (info->vq)
0101         vring_del_virtqueue(info->vq);
0102 
0103     memset(info->ring, 0, vring_size(num, 4096));
0104     vring_init(&info->vring, num, info->ring, 4096);
0105     info->vq = vring_new_virtqueue(info->idx, num, 4096, vdev, true, false,
0106                        info->ring, vq_notify, vq_callback, "test");
0107     assert(info->vq);
0108     info->vq->priv = info;
0109 }
0110 
0111 static void vq_info_add(struct vdev_info *dev, int num)
0112 {
0113     struct vq_info *info = &dev->vqs[dev->nvqs];
0114     int r;
0115     info->idx = dev->nvqs;
0116     info->kick = eventfd(0, EFD_NONBLOCK);
0117     info->call = eventfd(0, EFD_NONBLOCK);
0118     r = posix_memalign(&info->ring, 4096, vring_size(num, 4096));
0119     assert(r >= 0);
0120     vq_reset(info, num, &dev->vdev);
0121     vhost_vq_setup(dev, info);
0122     dev->fds[info->idx].fd = info->call;
0123     dev->fds[info->idx].events = POLLIN;
0124     dev->nvqs++;
0125 }
0126 
0127 static void vdev_info_init(struct vdev_info* dev, unsigned long long features)
0128 {
0129     int r;
0130     memset(dev, 0, sizeof *dev);
0131     dev->vdev.features = features;
0132     INIT_LIST_HEAD(&dev->vdev.vqs);
0133     spin_lock_init(&dev->vdev.vqs_list_lock);
0134     dev->buf_size = 1024;
0135     dev->buf = malloc(dev->buf_size);
0136     assert(dev->buf);
0137         dev->control = open("/dev/vhost-test", O_RDWR);
0138     assert(dev->control >= 0);
0139     r = ioctl(dev->control, VHOST_SET_OWNER, NULL);
0140     assert(r >= 0);
0141     dev->mem = malloc(offsetof(struct vhost_memory, regions) +
0142               sizeof dev->mem->regions[0]);
0143     assert(dev->mem);
0144     memset(dev->mem, 0, offsetof(struct vhost_memory, regions) +
0145                           sizeof dev->mem->regions[0]);
0146     dev->mem->nregions = 1;
0147     dev->mem->regions[0].guest_phys_addr = (long)dev->buf;
0148     dev->mem->regions[0].userspace_addr = (long)dev->buf;
0149     dev->mem->regions[0].memory_size = dev->buf_size;
0150     r = ioctl(dev->control, VHOST_SET_MEM_TABLE, dev->mem);
0151     assert(r >= 0);
0152 }
0153 
0154 /* TODO: this is pretty bad: we get a cache line bounce
0155  * for the wait queue on poll and another one on read,
0156  * plus the read which is there just to clear the
0157  * current state. */
0158 static void wait_for_interrupt(struct vdev_info *dev)
0159 {
0160     int i;
0161     unsigned long long val;
0162     poll(dev->fds, dev->nvqs, -1);
0163     for (i = 0; i < dev->nvqs; ++i)
0164         if (dev->fds[i].revents & POLLIN) {
0165             read(dev->fds[i].fd, &val, sizeof val);
0166         }
0167 }
0168 
0169 static void run_test(struct vdev_info *dev, struct vq_info *vq,
0170              bool delayed, int batch, int reset_n, int bufs)
0171 {
0172     struct scatterlist sl;
0173     long started = 0, completed = 0, next_reset = reset_n;
0174     long completed_before, started_before;
0175     int r, test = 1;
0176     unsigned len;
0177     long long spurious = 0;
0178     const bool random_batch = batch == RANDOM_BATCH;
0179 
0180     r = ioctl(dev->control, VHOST_TEST_RUN, &test);
0181     assert(r >= 0);
0182     if (!reset_n) {
0183         next_reset = INT_MAX;
0184     }
0185 
0186     for (;;) {
0187         virtqueue_disable_cb(vq->vq);
0188         completed_before = completed;
0189         started_before = started;
0190         do {
0191             const bool reset = completed > next_reset;
0192             if (random_batch)
0193                 batch = (random() % vq->vring.num) + 1;
0194 
0195             while (started < bufs &&
0196                    (started - completed) < batch) {
0197                 sg_init_one(&sl, dev->buf, dev->buf_size);
0198                 r = virtqueue_add_outbuf(vq->vq, &sl, 1,
0199                              dev->buf + started,
0200                              GFP_ATOMIC);
0201                 if (unlikely(r != 0)) {
0202                     if (r == -ENOSPC &&
0203                         started > started_before)
0204                         r = 0;
0205                     else
0206                         r = -1;
0207                     break;
0208                 }
0209 
0210                 ++started;
0211 
0212                 if (unlikely(!virtqueue_kick(vq->vq))) {
0213                     r = -1;
0214                     break;
0215                 }
0216             }
0217 
0218             if (started >= bufs)
0219                 r = -1;
0220 
0221             if (reset) {
0222                 r = ioctl(dev->control, VHOST_TEST_SET_BACKEND,
0223                       &no_backend);
0224                 assert(!r);
0225             }
0226 
0227             /* Flush out completed bufs if any */
0228             while (virtqueue_get_buf(vq->vq, &len)) {
0229                 ++completed;
0230                 r = 0;
0231             }
0232 
0233             if (reset) {
0234                 struct vhost_vring_state s = { .index = 0 };
0235 
0236                 vq_reset(vq, vq->vring.num, &dev->vdev);
0237 
0238                 r = ioctl(dev->control, VHOST_GET_VRING_BASE,
0239                       &s);
0240                 assert(!r);
0241 
0242                 s.num = 0;
0243                 r = ioctl(dev->control, VHOST_SET_VRING_BASE,
0244                       &null_state);
0245                 assert(!r);
0246 
0247                 r = ioctl(dev->control, VHOST_TEST_SET_BACKEND,
0248                       &backend);
0249                 assert(!r);
0250 
0251                 started = completed;
0252                 while (completed > next_reset)
0253                     next_reset += completed;
0254             }
0255         } while (r == 0);
0256         if (completed == completed_before && started == started_before)
0257             ++spurious;
0258         assert(completed <= bufs);
0259         assert(started <= bufs);
0260         if (completed == bufs)
0261             break;
0262         if (delayed) {
0263             if (virtqueue_enable_cb_delayed(vq->vq))
0264                 wait_for_interrupt(dev);
0265         } else {
0266             if (virtqueue_enable_cb(vq->vq))
0267                 wait_for_interrupt(dev);
0268         }
0269     }
0270     test = 0;
0271     r = ioctl(dev->control, VHOST_TEST_RUN, &test);
0272     assert(r >= 0);
0273     fprintf(stderr,
0274         "spurious wakeups: 0x%llx started=0x%lx completed=0x%lx\n",
0275         spurious, started, completed);
0276 }
0277 
/* Short options understood by getopt_long(): only -h. */
const char optstring[] = "h";

/*
 * Long options. Each entry returns its .val from getopt_long(); the table
 * ends with an all-zero entry.
 */
const struct option longopts[] = {
    { .name = "help",                 .val = 'h' },
    { .name = "event-idx",            .val = 'E' },
    { .name = "no-event-idx",         .val = 'e' },
    { .name = "indirect",             .val = 'I' },
    { .name = "no-indirect",          .val = 'i' },
    { .name = "virtio-1",             .val = '1' },
    { .name = "no-virtio-1",          .val = '0' },
    { .name = "delayed-interrupt",    .val = 'D' },
    { .name = "no-delayed-interrupt", .val = 'd' },
    { .name = "batch", .has_arg = required_argument, .val = 'b' },
    { .name = "reset", .has_arg = optional_argument, .val = 'r' },
    { .name = NULL },
};
0329 
/* Print the one-line usage summary to stderr. */
static void help(void)
{
    static const char usage[] =
        "Usage: virtio_test [--help]"
        " [--no-indirect]"
        " [--no-event-idx]"
        " [--no-virtio-1]"
        " [--delayed-interrupt]"
        " [--batch=random/N]"
        " [--reset=N]"
        "\n";

    fputs(usage, stderr);
}
0341 
0342 int main(int argc, char **argv)
0343 {
0344     struct vdev_info dev;
0345     unsigned long long features = (1ULL << VIRTIO_RING_F_INDIRECT_DESC) |
0346         (1ULL << VIRTIO_RING_F_EVENT_IDX) | (1ULL << VIRTIO_F_VERSION_1);
0347     long batch = 1, reset = 0;
0348     int o;
0349     bool delayed = false;
0350 
0351     for (;;) {
0352         o = getopt_long(argc, argv, optstring, longopts, NULL);
0353         switch (o) {
0354         case -1:
0355             goto done;
0356         case '?':
0357             help();
0358             exit(2);
0359         case 'e':
0360             features &= ~(1ULL << VIRTIO_RING_F_EVENT_IDX);
0361             break;
0362         case 'h':
0363             help();
0364             goto done;
0365         case 'i':
0366             features &= ~(1ULL << VIRTIO_RING_F_INDIRECT_DESC);
0367             break;
0368         case '0':
0369             features &= ~(1ULL << VIRTIO_F_VERSION_1);
0370             break;
0371         case 'D':
0372             delayed = true;
0373             break;
0374         case 'b':
0375             if (0 == strcmp(optarg, "random")) {
0376                 batch = RANDOM_BATCH;
0377             } else {
0378                 batch = strtol(optarg, NULL, 10);
0379                 assert(batch > 0);
0380                 assert(batch < (long)INT_MAX + 1);
0381             }
0382             break;
0383         case 'r':
0384             if (!optarg) {
0385                 reset = 1;
0386             } else {
0387                 reset = strtol(optarg, NULL, 10);
0388                 assert(reset > 0);
0389                 assert(reset < (long)INT_MAX + 1);
0390             }
0391             break;
0392         default:
0393             assert(0);
0394             break;
0395         }
0396     }
0397 
0398 done:
0399     vdev_info_init(&dev, features);
0400     vq_info_add(&dev, 256);
0401     run_test(&dev, &dev.vqs[0], delayed, batch, reset, 0x100000);
0402     return 0;
0403 }