0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * vhost transport for vsock
0004  *
0005  * Copyright (C) 2013-2015 Red Hat, Inc.
0006  * Author: Asias He <asias@redhat.com>
0007  *         Stefan Hajnoczi <stefanha@redhat.com>
0008  */
0009 #include <linux/miscdevice.h>
0010 #include <linux/atomic.h>
0011 #include <linux/module.h>
0012 #include <linux/mutex.h>
0013 #include <linux/vmalloc.h>
0014 #include <net/sock.h>
0015 #include <linux/virtio_vsock.h>
0016 #include <linux/vhost.h>
0017 #include <linux/hashtable.h>
0018 
0019 #include <net/af_vsock.h>
0020 #include "vhost.h"
0021 
0022 #define VHOST_VSOCK_DEFAULT_HOST_CID    2
0023 /* Max number of bytes transferred before requeueing the job.
0024  * Using this limit prevents one virtqueue from starving others. */
0025 #define VHOST_VSOCK_WEIGHT 0x80000
0026 /* Max number of packets transferred before requeueing the job.
0027  * Using this limit prevents one virtqueue from starving others with
0028  * small pkts.
0029  */
0030 #define VHOST_VSOCK_PKT_WEIGHT 256
0031 
0032 enum {
0033     VHOST_VSOCK_FEATURES = VHOST_FEATURES |
0034                    (1ULL << VIRTIO_F_ACCESS_PLATFORM) |
0035                    (1ULL << VIRTIO_VSOCK_F_SEQPACKET)
0036 };
0037 
0038 enum {
0039     VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
0040 };
0041 
0042 /* Used to track all the vhost_vsock instances on the system. */
0043 static DEFINE_MUTEX(vhost_vsock_mutex);
0044 static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
0045 
0046 struct vhost_vsock {
0047     struct vhost_dev dev;
0048     struct vhost_virtqueue vqs[2]; /* indexed by VSOCK_VQ_TX/VSOCK_VQ_RX (guest's point of view) */
0049 
0050     /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
0051     struct hlist_node hash;
0052 
0053     struct vhost_work send_pkt_work;
0054     spinlock_t send_pkt_list_lock;
0055     struct list_head send_pkt_list; /* host->guest pending packets */
0056 
0057     atomic_t queued_replies;
0058 
0059     u32 guest_cid;
0060     bool seqpacket_allow;
0061 };
0062 
0063 static u32 vhost_transport_get_local_cid(void)
0064 {
0065     return VHOST_VSOCK_DEFAULT_HOST_CID;
0066 }
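The CID returned here is the fixed, well-known host address: VHOST_VSOCK_DEFAULT_HOST_CID equals VMADDR_CID_HOST (2), so a guest reaches the host simply by connecting to CID 2 over AF_VSOCK. As a purely illustrative sketch (not part of this driver), a minimal guest-side client could look like the following; the port number 1234 and the payload are arbitrary assumptions:

/* Illustrative guest-side AF_VSOCK client; connects to the host at CID 2.
 * Port 1234 is a hypothetical service port, not defined anywhere in this file.
 */
#include <stdio.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/vm_sockets.h>

int main(void)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid    = VMADDR_CID_HOST,	/* same value as VHOST_VSOCK_DEFAULT_HOST_CID */
		.svm_port   = 1234,		/* hypothetical service port */
	};
	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

	if (fd < 0 || connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("vsock connect");
		return 1;
	}
	write(fd, "hello\n", 6);
	close(fd);
	return 0;
}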
0067 
0068 /* Callers that dereference the return value must hold vhost_vsock_mutex or the
0069  * RCU read lock.
0070  */
0071 static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
0072 {
0073     struct vhost_vsock *vsock;
0074 
0075     hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
0076         u32 other_cid = vsock->guest_cid;
0077 
0078         /* Skip instances that have no CID yet */
0079         if (other_cid == 0)
0080             continue;
0081 
0082         if (other_cid == guest_cid)
0083             return vsock;
0084 
0085     }
0086 
0087     return NULL;
0088 }
0089 
0090 static void
0091 vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
0092                 struct vhost_virtqueue *vq)
0093 {
0094     struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
0095     int pkts = 0, total_len = 0;
0096     bool added = false;
0097     bool restart_tx = false;
0098 
0099     mutex_lock(&vq->mutex);
0100 
0101     if (!vhost_vq_get_backend(vq))
0102         goto out;
0103 
0104     if (!vq_meta_prefetch(vq))
0105         goto out;
0106 
0107     /* Avoid further vmexits, we're already processing the virtqueue */
0108     vhost_disable_notify(&vsock->dev, vq);
0109 
0110     do {
0111         struct virtio_vsock_pkt *pkt;
0112         struct iov_iter iov_iter;
0113         unsigned out, in;
0114         size_t nbytes;
0115         size_t iov_len, payload_len;
0116         int head;
0117         u32 flags_to_restore = 0;
0118 
0119         spin_lock_bh(&vsock->send_pkt_list_lock);
0120         if (list_empty(&vsock->send_pkt_list)) {
0121             spin_unlock_bh(&vsock->send_pkt_list_lock);
0122             vhost_enable_notify(&vsock->dev, vq);
0123             break;
0124         }
0125 
0126         pkt = list_first_entry(&vsock->send_pkt_list,
0127                        struct virtio_vsock_pkt, list);
0128         list_del_init(&pkt->list);
0129         spin_unlock_bh(&vsock->send_pkt_list_lock);
0130 
0131         head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
0132                      &out, &in, NULL, NULL);
0133         if (head < 0) {
0134             spin_lock_bh(&vsock->send_pkt_list_lock);
0135             list_add(&pkt->list, &vsock->send_pkt_list);
0136             spin_unlock_bh(&vsock->send_pkt_list_lock);
0137             break;
0138         }
0139 
0140         if (head == vq->num) {
0141             spin_lock_bh(&vsock->send_pkt_list_lock);
0142             list_add(&pkt->list, &vsock->send_pkt_list);
0143             spin_unlock_bh(&vsock->send_pkt_list_lock);
0144 
0145             /* We cannot finish yet if more buffers snuck in while
0146              * re-enabling notify.
0147              */
0148             if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
0149                 vhost_disable_notify(&vsock->dev, vq);
0150                 continue;
0151             }
0152             break;
0153         }
0154 
0155         if (out) {
0156             virtio_transport_free_pkt(pkt);
0157             vq_err(vq, "Expected 0 output buffers, got %u\n", out);
0158             break;
0159         }
0160 
0161         iov_len = iov_length(&vq->iov[out], in);
0162         if (iov_len < sizeof(pkt->hdr)) {
0163             virtio_transport_free_pkt(pkt);
0164             vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
0165             break;
0166         }
0167 
0168         iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
0169         payload_len = pkt->len - pkt->off;
0170 
0171         /* If the packet is greater than the space available in the
0172          * buffer, we split it using multiple buffers.
0173          */
0174         if (payload_len > iov_len - sizeof(pkt->hdr)) {
0175             payload_len = iov_len - sizeof(pkt->hdr);
0176 
0177             /* As we are copying pieces of a large packet's buffer into
0178              * small rx buffers, the headers of the packets placed in the
0179              * rx queue are created dynamically and initialized from the
0180              * header of the current packet (except for the length). For
0181              * SOCK_SEQPACKET we must also clear the message delimiter
0182              * bit (VIRTIO_VSOCK_SEQ_EOM) and the MSG_EOR bit
0183              * (VIRTIO_VSOCK_SEQ_EOR) if they are set; otherwise every
0184              * fragment of the message would carry them. Once the header
0185              * has been copied to the rx buffer, the cleared bits are
0186              * restored (see the worked example after this function).
0187              */
0188             if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
0189                 pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
0190                 flags_to_restore |= VIRTIO_VSOCK_SEQ_EOM;
0191 
0192                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
0193                     pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
0194                     flags_to_restore |= VIRTIO_VSOCK_SEQ_EOR;
0195                 }
0196             }
0197         }
0198 
0199         /* Set the correct length in the header */
0200         pkt->hdr.len = cpu_to_le32(payload_len);
0201 
0202         nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
0203         if (nbytes != sizeof(pkt->hdr)) {
0204             virtio_transport_free_pkt(pkt);
0205             vq_err(vq, "Faulted on copying pkt hdr\n");
0206             break;
0207         }
0208 
0209         nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
0210                       &iov_iter);
0211         if (nbytes != payload_len) {
0212             virtio_transport_free_pkt(pkt);
0213             vq_err(vq, "Faulted on copying pkt buf\n");
0214             break;
0215         }
0216 
0217         /* Deliver to monitoring devices all packets that we
0218          * will transmit.
0219          */
0220         virtio_transport_deliver_tap_pkt(pkt);
0221 
0222         vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
0223         added = true;
0224 
0225         pkt->off += payload_len;
0226         total_len += payload_len;
0227 
0228         /* If we didn't send all the payload we can requeue the packet
0229          * to send it with the next available buffer.
0230          */
0231         if (pkt->off < pkt->len) {
0232             pkt->hdr.flags |= cpu_to_le32(flags_to_restore);
0233 
0234             /* We are queueing the same virtio_vsock_pkt to handle
0235              * the remaining bytes, and we want to deliver it
0236              * to monitoring devices in the next iteration.
0237              */
0238             pkt->tap_delivered = false;
0239 
0240             spin_lock_bh(&vsock->send_pkt_list_lock);
0241             list_add(&pkt->list, &vsock->send_pkt_list);
0242             spin_unlock_bh(&vsock->send_pkt_list_lock);
0243         } else {
0244             if (pkt->reply) {
0245                 int val;
0246 
0247                 val = atomic_dec_return(&vsock->queued_replies);
0248 
0249                 /* Do we have resources to resume tx
0250                  * processing?
0251                  */
0252                 if (val + 1 == tx_vq->num)
0253                     restart_tx = true;
0254             }
0255 
0256             virtio_transport_free_pkt(pkt);
0257         }
0258     } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
0259     if (added)
0260         vhost_signal(&vsock->dev, vq);
0261 
0262 out:
0263     mutex_unlock(&vq->mutex);
0264 
0265     if (restart_tx)
0266         vhost_poll_queue(&tx_vq->poll);
0267 }
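A brief worked example of the splitting handled above, under the assumption (not stated in this file) that the guest posts 4 KiB rx buffers: each buffer holds 4096 bytes minus the 44-byte virtio_vsock header, i.e. up to 4052 bytes of payload, so a 16 KiB SOCK_SEQPACKET message is delivered as four 4052-byte fragments plus a final 176-byte fragment, and only that final fragment carries VIRTIO_VSOCK_SEQ_EOM (and VIRTIO_VSOCK_SEQ_EOR, if it was set on the original packet).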
0268 
0269 static void vhost_transport_send_pkt_work(struct vhost_work *work)
0270 {
0271     struct vhost_virtqueue *vq;
0272     struct vhost_vsock *vsock;
0273 
0274     vsock = container_of(work, struct vhost_vsock, send_pkt_work);
0275     vq = &vsock->vqs[VSOCK_VQ_RX];
0276 
0277     vhost_transport_do_send_pkt(vsock, vq);
0278 }
0279 
0280 static int
0281 vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
0282 {
0283     struct vhost_vsock *vsock;
0284     int len = pkt->len;
0285 
0286     rcu_read_lock();
0287 
0288     /* Find the vhost_vsock according to the guest context ID */
0289     vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
0290     if (!vsock) {
0291         rcu_read_unlock();
0292         virtio_transport_free_pkt(pkt);
0293         return -ENODEV;
0294     }
0295 
0296     if (pkt->reply)
0297         atomic_inc(&vsock->queued_replies);
0298 
0299     spin_lock_bh(&vsock->send_pkt_list_lock);
0300     list_add_tail(&pkt->list, &vsock->send_pkt_list);
0301     spin_unlock_bh(&vsock->send_pkt_list_lock);
0302 
0303     vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
0304 
0305     rcu_read_unlock();
0306     return len;
0307 }
0308 
0309 static int
0310 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
0311 {
0312     struct vhost_vsock *vsock;
0313     struct virtio_vsock_pkt *pkt, *n;
0314     int cnt = 0;
0315     int ret = -ENODEV;
0316     LIST_HEAD(freeme);
0317 
0318     rcu_read_lock();
0319 
0320     /* Find the vhost_vsock according to the guest context ID */
0321     vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
0322     if (!vsock)
0323         goto out;
0324 
0325     spin_lock_bh(&vsock->send_pkt_list_lock);
0326     list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
0327         if (pkt->vsk != vsk)
0328             continue;
0329         list_move(&pkt->list, &freeme);
0330     }
0331     spin_unlock_bh(&vsock->send_pkt_list_lock);
0332 
0333     list_for_each_entry_safe(pkt, n, &freeme, list) {
0334         if (pkt->reply)
0335             cnt++;
0336         list_del(&pkt->list);
0337         virtio_transport_free_pkt(pkt);
0338     }
0339 
0340     if (cnt) {
0341         struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
0342         int new_cnt;
0343 
0344         new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
0345         if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
0346             vhost_poll_queue(&tx_vq->poll);
0347     }
0348 
0349     ret = 0;
0350 out:
0351     rcu_read_unlock();
0352     return ret;
0353 }
0354 
0355 static struct virtio_vsock_pkt *
0356 vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
0357               unsigned int out, unsigned int in)
0358 {
0359     struct virtio_vsock_pkt *pkt;
0360     struct iov_iter iov_iter;
0361     size_t nbytes;
0362     size_t len;
0363 
0364     if (in != 0) {
0365         vq_err(vq, "Expected 0 input buffers, got %u\n", in);
0366         return NULL;
0367     }
0368 
0369     pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
0370     if (!pkt)
0371         return NULL;
0372 
0373     len = iov_length(vq->iov, out);
0374     iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
0375 
0376     nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
0377     if (nbytes != sizeof(pkt->hdr)) {
0378         vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
0379                sizeof(pkt->hdr), nbytes);
0380         kfree(pkt);
0381         return NULL;
0382     }
0383 
0384     pkt->len = le32_to_cpu(pkt->hdr.len);
0385 
0386     /* No payload */
0387     if (!pkt->len)
0388         return pkt;
0389 
0390     /* The pkt is too big */
0391     if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
0392         kfree(pkt);
0393         return NULL;
0394     }
0395 
0396     pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
0397     if (!pkt->buf) {
0398         kfree(pkt);
0399         return NULL;
0400     }
0401 
0402     pkt->buf_len = pkt->len;
0403 
0404     nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
0405     if (nbytes != pkt->len) {
0406         vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
0407                pkt->len, nbytes);
0408         virtio_transport_free_pkt(pkt);
0409         return NULL;
0410     }
0411 
0412     return pkt;
0413 }
0414 
0415 /* Is there space left for replies to rx packets? */
0416 static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
0417 {
0418     struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
0419     int val;
0420 
0421     smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
0422     val = atomic_read(&vsock->queued_replies);
0423 
0424     return val < vq->num;
0425 }
0426 
0427 static bool vhost_transport_seqpacket_allow(u32 remote_cid);
0428 
0429 static struct virtio_transport vhost_transport = {
0430     .transport = {
0431         .module                   = THIS_MODULE,
0432 
0433         .get_local_cid            = vhost_transport_get_local_cid,
0434 
0435         .init                     = virtio_transport_do_socket_init,
0436         .destruct                 = virtio_transport_destruct,
0437         .release                  = virtio_transport_release,
0438         .connect                  = virtio_transport_connect,
0439         .shutdown                 = virtio_transport_shutdown,
0440         .cancel_pkt               = vhost_transport_cancel_pkt,
0441 
0442         .dgram_enqueue            = virtio_transport_dgram_enqueue,
0443         .dgram_dequeue            = virtio_transport_dgram_dequeue,
0444         .dgram_bind               = virtio_transport_dgram_bind,
0445         .dgram_allow              = virtio_transport_dgram_allow,
0446 
0447         .stream_enqueue           = virtio_transport_stream_enqueue,
0448         .stream_dequeue           = virtio_transport_stream_dequeue,
0449         .stream_has_data          = virtio_transport_stream_has_data,
0450         .stream_has_space         = virtio_transport_stream_has_space,
0451         .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
0452         .stream_is_active         = virtio_transport_stream_is_active,
0453         .stream_allow             = virtio_transport_stream_allow,
0454 
0455         .seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
0456         .seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
0457         .seqpacket_allow          = vhost_transport_seqpacket_allow,
0458         .seqpacket_has_data       = virtio_transport_seqpacket_has_data,
0459 
0460         .notify_poll_in           = virtio_transport_notify_poll_in,
0461         .notify_poll_out          = virtio_transport_notify_poll_out,
0462         .notify_recv_init         = virtio_transport_notify_recv_init,
0463         .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
0464         .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
0465         .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
0466         .notify_send_init         = virtio_transport_notify_send_init,
0467         .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
0468         .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
0469         .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
0470         .notify_buffer_size       = virtio_transport_notify_buffer_size,
0471 
0472     },
0473 
0474     .send_pkt = vhost_transport_send_pkt,
0475 };
0476 
0477 static bool vhost_transport_seqpacket_allow(u32 remote_cid)
0478 {
0479     struct vhost_vsock *vsock;
0480     bool seqpacket_allow = false;
0481 
0482     rcu_read_lock();
0483     vsock = vhost_vsock_get(remote_cid);
0484 
0485     if (vsock)
0486         seqpacket_allow = vsock->seqpacket_allow;
0487 
0488     rcu_read_unlock();
0489 
0490     return seqpacket_allow;
0491 }
0492 
0493 static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
0494 {
0495     struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
0496                           poll.work);
0497     struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
0498                          dev);
0499     struct virtio_vsock_pkt *pkt;
0500     int head, pkts = 0, total_len = 0;
0501     unsigned int out, in;
0502     bool added = false;
0503 
0504     mutex_lock(&vq->mutex);
0505 
0506     if (!vhost_vq_get_backend(vq))
0507         goto out;
0508 
0509     if (!vq_meta_prefetch(vq))
0510         goto out;
0511 
0512     vhost_disable_notify(&vsock->dev, vq);
0513     do {
0514         if (!vhost_vsock_more_replies(vsock)) {
0515             /* Stop tx until the device processes already
0516              * pending replies.  Leave tx virtqueue
0517              * callbacks disabled.
0518              */
0519             goto no_more_replies;
0520         }
0521 
0522         head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
0523                      &out, &in, NULL, NULL);
0524         if (head < 0)
0525             break;
0526 
0527         if (head == vq->num) {
0528             if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
0529                 vhost_disable_notify(&vsock->dev, vq);
0530                 continue;
0531             }
0532             break;
0533         }
0534 
0535         pkt = vhost_vsock_alloc_pkt(vq, out, in);
0536         if (!pkt) {
0537             vq_err(vq, "Faulted on pkt\n");
0538             continue;
0539         }
0540 
0541         total_len += sizeof(pkt->hdr) + pkt->len;
0542 
0543         /* Deliver to monitoring devices all received packets */
0544         virtio_transport_deliver_tap_pkt(pkt);
0545 
0546         /* Only accept correctly addressed packets */
0547         if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
0548             le64_to_cpu(pkt->hdr.dst_cid) ==
0549             vhost_transport_get_local_cid())
0550             virtio_transport_recv_pkt(&vhost_transport, pkt);
0551         else
0552             virtio_transport_free_pkt(pkt);
0553 
0554         vhost_add_used(vq, head, 0);
0555         added = true;
0556     } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
0557 
0558 no_more_replies:
0559     if (added)
0560         vhost_signal(&vsock->dev, vq);
0561 
0562 out:
0563     mutex_unlock(&vq->mutex);
0564 }
0565 
0566 static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
0567 {
0568     struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
0569                         poll.work);
0570     struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
0571                          dev);
0572 
0573     vhost_transport_do_send_pkt(vsock, vq);
0574 }
0575 
0576 static int vhost_vsock_start(struct vhost_vsock *vsock)
0577 {
0578     struct vhost_virtqueue *vq;
0579     size_t i;
0580     int ret;
0581 
0582     mutex_lock(&vsock->dev.mutex);
0583 
0584     ret = vhost_dev_check_owner(&vsock->dev);
0585     if (ret)
0586         goto err;
0587 
0588     for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
0589         vq = &vsock->vqs[i];
0590 
0591         mutex_lock(&vq->mutex);
0592 
0593         if (!vhost_vq_access_ok(vq)) {
0594             ret = -EFAULT;
0595             goto err_vq;
0596         }
0597 
0598         if (!vhost_vq_get_backend(vq)) {
0599             vhost_vq_set_backend(vq, vsock);
0600             ret = vhost_vq_init_access(vq);
0601             if (ret)
0602                 goto err_vq;
0603         }
0604 
0605         mutex_unlock(&vq->mutex);
0606     }
0607 
0608     /* Some packets may have been queued before the device was started,
0609      * so kick the send worker to send them.
0610      */
0611     vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
0612 
0613     mutex_unlock(&vsock->dev.mutex);
0614     return 0;
0615 
0616 err_vq:
0617     vhost_vq_set_backend(vq, NULL);
0618     mutex_unlock(&vq->mutex);
0619 
0620     for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
0621         vq = &vsock->vqs[i];
0622 
0623         mutex_lock(&vq->mutex);
0624         vhost_vq_set_backend(vq, NULL);
0625         mutex_unlock(&vq->mutex);
0626     }
0627 err:
0628     mutex_unlock(&vsock->dev.mutex);
0629     return ret;
0630 }
0631 
0632 static int vhost_vsock_stop(struct vhost_vsock *vsock, bool check_owner)
0633 {
0634     size_t i;
0635     int ret = 0;
0636 
0637     mutex_lock(&vsock->dev.mutex);
0638 
0639     if (check_owner) {
0640         ret = vhost_dev_check_owner(&vsock->dev);
0641         if (ret)
0642             goto err;
0643     }
0644 
0645     for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
0646         struct vhost_virtqueue *vq = &vsock->vqs[i];
0647 
0648         mutex_lock(&vq->mutex);
0649         vhost_vq_set_backend(vq, NULL);
0650         mutex_unlock(&vq->mutex);
0651     }
0652 
0653 err:
0654     mutex_unlock(&vsock->dev.mutex);
0655     return ret;
0656 }
0657 
0658 static void vhost_vsock_free(struct vhost_vsock *vsock)
0659 {
0660     kvfree(vsock);
0661 }
0662 
0663 static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
0664 {
0665     struct vhost_virtqueue **vqs;
0666     struct vhost_vsock *vsock;
0667     int ret;
0668 
0669     /* This struct is large and allocation could fail, so fall back to
0670      * vmalloc if there is no other way.
0671      */
0672     vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
0673     if (!vsock)
0674         return -ENOMEM;
0675 
0676     vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
0677     if (!vqs) {
0678         ret = -ENOMEM;
0679         goto out;
0680     }
0681 
0682     vsock->guest_cid = 0; /* no CID assigned yet */
0683 
0684     atomic_set(&vsock->queued_replies, 0);
0685 
0686     vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
0687     vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
0688     vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
0689     vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
0690 
0691     vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
0692                UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
0693                VHOST_VSOCK_WEIGHT, true, NULL);
0694 
0695     file->private_data = vsock;
0696     spin_lock_init(&vsock->send_pkt_list_lock);
0697     INIT_LIST_HEAD(&vsock->send_pkt_list);
0698     vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
0699     return 0;
0700 
0701 out:
0702     vhost_vsock_free(vsock);
0703     return ret;
0704 }
0705 
0706 static void vhost_vsock_flush(struct vhost_vsock *vsock)
0707 {
0708     vhost_dev_flush(&vsock->dev);
0709 }
0710 
0711 static void vhost_vsock_reset_orphans(struct sock *sk)
0712 {
0713     struct vsock_sock *vsk = vsock_sk(sk);
0714 
0715     /* vmci_transport.c doesn't take sk_lock here either.  At least we're
0716      * under vsock_table_lock so the sock cannot disappear while we're
0717      * executing.
0718      */
0719 
0720     /* If the peer is still valid, no need to reset connection */
0721     if (vhost_vsock_get(vsk->remote_addr.svm_cid))
0722         return;
0723 
0724     /* If the close timeout is pending, let it expire.  This avoids races
0725      * with the timeout callback.
0726      */
0727     if (vsk->close_work_scheduled)
0728         return;
0729 
0730     sock_set_flag(sk, SOCK_DONE);
0731     vsk->peer_shutdown = SHUTDOWN_MASK;
0732     sk->sk_state = SS_UNCONNECTED;
0733     sk->sk_err = ECONNRESET;
0734     sk_error_report(sk);
0735 }
0736 
0737 static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
0738 {
0739     struct vhost_vsock *vsock = file->private_data;
0740 
0741     mutex_lock(&vhost_vsock_mutex);
0742     if (vsock->guest_cid)
0743         hash_del_rcu(&vsock->hash);
0744     mutex_unlock(&vhost_vsock_mutex);
0745 
0746     /* Wait for other CPUs to finish using vsock */
0747     synchronize_rcu();
0748 
0749     /* Iterating over all connections for all CIDs to find orphans is
0750      * inefficient.  Room for improvement here. */
0751     vsock_for_each_connected_socket(&vhost_transport.transport,
0752                     vhost_vsock_reset_orphans);
0753 
0754     /* Don't check the owner, because we are in the release path, so we
0755      * need to stop the vsock device in any case.
0756      * vhost_vsock_stop() can not fail in this case, so we don't need to
0757      * check the return code.
0758      */
0759     vhost_vsock_stop(vsock, false);
0760     vhost_vsock_flush(vsock);
0761     vhost_dev_stop(&vsock->dev);
0762 
0763     spin_lock_bh(&vsock->send_pkt_list_lock);
0764     while (!list_empty(&vsock->send_pkt_list)) {
0765         struct virtio_vsock_pkt *pkt;
0766 
0767         pkt = list_first_entry(&vsock->send_pkt_list,
0768                 struct virtio_vsock_pkt, list);
0769         list_del_init(&pkt->list);
0770         virtio_transport_free_pkt(pkt);
0771     }
0772     spin_unlock_bh(&vsock->send_pkt_list_lock);
0773 
0774     vhost_dev_cleanup(&vsock->dev);
0775     kfree(vsock->dev.vqs);
0776     vhost_vsock_free(vsock);
0777     return 0;
0778 }
0779 
0780 static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
0781 {
0782     struct vhost_vsock *other;
0783 
0784     /* Refuse reserved CIDs */
0785     if (guest_cid <= VMADDR_CID_HOST ||
0786         guest_cid == U32_MAX)
0787         return -EINVAL;
0788 
0789     /* 64-bit CIDs are not yet supported */
0790     if (guest_cid > U32_MAX)
0791         return -EINVAL;
0792 
0793     /* Refuse if CID is assigned to the guest->host transport (i.e. nested
0794      * VM), to make the loopback work.
0795      */
0796     if (vsock_find_cid(guest_cid))
0797         return -EADDRINUSE;
0798 
0799     /* Refuse if CID is already in use */
0800     mutex_lock(&vhost_vsock_mutex);
0801     other = vhost_vsock_get(guest_cid);
0802     if (other && other != vsock) {
0803         mutex_unlock(&vhost_vsock_mutex);
0804         return -EADDRINUSE;
0805     }
0806 
0807     if (vsock->guest_cid)
0808         hash_del_rcu(&vsock->hash);
0809 
0810     vsock->guest_cid = guest_cid;
0811     hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
0812     mutex_unlock(&vhost_vsock_mutex);
0813 
0814     return 0;
0815 }
0816 
0817 static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
0818 {
0819     struct vhost_virtqueue *vq;
0820     int i;
0821 
0822     if (features & ~VHOST_VSOCK_FEATURES)
0823         return -EOPNOTSUPP;
0824 
0825     mutex_lock(&vsock->dev.mutex);
0826     if ((features & (1 << VHOST_F_LOG_ALL)) &&
0827         !vhost_log_access_ok(&vsock->dev)) {
0828         goto err;
0829     }
0830 
0831     if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
0832         if (vhost_init_device_iotlb(&vsock->dev, true))
0833             goto err;
0834     }
0835 
0836     if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
0837         vsock->seqpacket_allow = true;
0838 
0839     for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
0840         vq = &vsock->vqs[i];
0841         mutex_lock(&vq->mutex);
0842         vq->acked_features = features;
0843         mutex_unlock(&vq->mutex);
0844     }
0845     mutex_unlock(&vsock->dev.mutex);
0846     return 0;
0847 
0848 err:
0849     mutex_unlock(&vsock->dev.mutex);
0850     return -EFAULT;
0851 }
0852 
0853 static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
0854                   unsigned long arg)
0855 {
0856     struct vhost_vsock *vsock = f->private_data;
0857     void __user *argp = (void __user *)arg;
0858     u64 guest_cid;
0859     u64 features;
0860     int start;
0861     int r;
0862 
0863     switch (ioctl) {
0864     case VHOST_VSOCK_SET_GUEST_CID:
0865         if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
0866             return -EFAULT;
0867         return vhost_vsock_set_cid(vsock, guest_cid);
0868     case VHOST_VSOCK_SET_RUNNING:
0869         if (copy_from_user(&start, argp, sizeof(start)))
0870             return -EFAULT;
0871         if (start)
0872             return vhost_vsock_start(vsock);
0873         else
0874             return vhost_vsock_stop(vsock, true);
0875     case VHOST_GET_FEATURES:
0876         features = VHOST_VSOCK_FEATURES;
0877         if (copy_to_user(argp, &features, sizeof(features)))
0878             return -EFAULT;
0879         return 0;
0880     case VHOST_SET_FEATURES:
0881         if (copy_from_user(&features, argp, sizeof(features)))
0882             return -EFAULT;
0883         return vhost_vsock_set_features(vsock, features);
0884     case VHOST_GET_BACKEND_FEATURES:
0885         features = VHOST_VSOCK_BACKEND_FEATURES;
0886         if (copy_to_user(argp, &features, sizeof(features)))
0887             return -EFAULT;
0888         return 0;
0889     case VHOST_SET_BACKEND_FEATURES:
0890         if (copy_from_user(&features, argp, sizeof(features)))
0891             return -EFAULT;
0892         if (features & ~VHOST_VSOCK_BACKEND_FEATURES)
0893             return -EOPNOTSUPP;
0894         vhost_set_backend_features(&vsock->dev, features);
0895         return 0;
0896     default:
0897         mutex_lock(&vsock->dev.mutex);
0898         r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
0899         if (r == -ENOIOCTLCMD)
0900             r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
0901         else
0902             vhost_vsock_flush(vsock);
0903         mutex_unlock(&vsock->dev.mutex);
0904         return r;
0905     }
0906 }
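For context, here is a hedged sketch of the control-plane sequence a userspace VMM drives through these ioctls on /dev/vhost-vsock. The guest CID value 3 is an arbitrary example, and the sketch deliberately omits the VHOST_SET_MEM_TABLE and VHOST_SET_VRING_* configuration that must come first; without it, the VHOST_VSOCK_SET_RUNNING call would fail in vhost_vsock_start() with -EFAULT, so this only shows the calling convention:

/* Hedged sketch: claim the device, assign a guest CID, and start it.
 * Memory-table and vring setup are intentionally omitted (see lead-in).
 */
#include <stdio.h>
#include <fcntl.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

int main(void)
{
	uint64_t guest_cid = 3;		/* example CID; must be > VMADDR_CID_HOST and < U32_MAX */
	int running = 1;
	int fd = open("/dev/vhost-vsock", O_RDWR);

	if (fd < 0) {
		perror("open /dev/vhost-vsock");
		return 1;
	}
	if (ioctl(fd, VHOST_SET_OWNER, NULL) ||
	    ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid) ||
	    ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running)) {
		perror("vhost-vsock ioctl");
		close(fd);
		return 1;
	}
	close(fd);
	return 0;
}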
0907 
0908 static ssize_t vhost_vsock_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
0909 {
0910     struct file *file = iocb->ki_filp;
0911     struct vhost_vsock *vsock = file->private_data;
0912     struct vhost_dev *dev = &vsock->dev;
0913     int noblock = file->f_flags & O_NONBLOCK;
0914 
0915     return vhost_chr_read_iter(dev, to, noblock);
0916 }
0917 
0918 static ssize_t vhost_vsock_chr_write_iter(struct kiocb *iocb,
0919                     struct iov_iter *from)
0920 {
0921     struct file *file = iocb->ki_filp;
0922     struct vhost_vsock *vsock = file->private_data;
0923     struct vhost_dev *dev = &vsock->dev;
0924 
0925     return vhost_chr_write_iter(dev, from);
0926 }
0927 
0928 static __poll_t vhost_vsock_chr_poll(struct file *file, poll_table *wait)
0929 {
0930     struct vhost_vsock *vsock = file->private_data;
0931     struct vhost_dev *dev = &vsock->dev;
0932 
0933     return vhost_chr_poll(file, dev, wait);
0934 }
0935 
0936 static const struct file_operations vhost_vsock_fops = {
0937     .owner          = THIS_MODULE,
0938     .open           = vhost_vsock_dev_open,
0939     .release        = vhost_vsock_dev_release,
0940     .llseek         = noop_llseek,
0941     .unlocked_ioctl = vhost_vsock_dev_ioctl,
0942     .compat_ioctl   = compat_ptr_ioctl,
0943     .read_iter      = vhost_vsock_chr_read_iter,
0944     .write_iter     = vhost_vsock_chr_write_iter,
0945     .poll           = vhost_vsock_chr_poll,
0946 };
0947 
0948 static struct miscdevice vhost_vsock_misc = {
0949     .minor = VHOST_VSOCK_MINOR,
0950     .name = "vhost-vsock",
0951     .fops = &vhost_vsock_fops,
0952 };
0953 
0954 static int __init vhost_vsock_init(void)
0955 {
0956     int ret;
0957 
0958     ret = vsock_core_register(&vhost_transport.transport,
0959                   VSOCK_TRANSPORT_F_H2G);
0960     if (ret < 0)
0961         return ret;
0962     return misc_register(&vhost_vsock_misc);
0963 }
0964 
0965 static void __exit vhost_vsock_exit(void)
0966 {
0967     misc_deregister(&vhost_vsock_misc);
0968     vsock_core_unregister(&vhost_transport.transport);
0969 }
0970 
0971 module_init(vhost_vsock_init);
0972 module_exit(vhost_vsock_exit);
0973 MODULE_LICENSE("GPL v2");
0974 MODULE_AUTHOR("Asias He");
0975 MODULE_DESCRIPTION("vhost transport for vsock");
0976 MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
0977 MODULE_ALIAS("devname:vhost-vsock");