// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID	2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000

/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES |
			       (1ULL << VIRTIO_F_ACCESS_PLATFORM) |
			       (1ULL << VIRTIO_VSOCK_F_SEQPACKET)
};

enum {
	VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[2];

	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;

	struct vhost_work send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;	/* host->guest pending packets */

	atomic_t queued_replies;

	u32 guest_cid;
	bool seqpacket_allow;
};

static u32 vhost_transport_get_local_cid(void)
{
	return VHOST_VSOCK_DEFAULT_HOST_CID;
}

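/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */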
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
		u32 other_cid = vsock->guest_cid;

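		/* Skip instances that have no CID yet */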
		if (other_cid == 0)
			continue;

		if (other_cid == guest_cid)
			return vsock;
	}

	return NULL;
}

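/* Move packets queued on vsock->send_pkt_list into the guest's RX virtqueue,
 * splitting a packet across several guest buffers when necessary.
 */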
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
	int pkts = 0, total_len = 0;
	bool added = false;
	bool restart_tx = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

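	/* Avoid further vmexits, we're already processing the virtqueue */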
	vhost_disable_notify(&vsock->dev, vq);

	do {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned out, in;
		size_t nbytes;
		size_t iov_len, payload_len;
		int head;
		u32 flags_to_restore = 0;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (head == vq->num) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);

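			/* We cannot finish yet if more buffers snuck in while
			 * re-enabling notify.
			 */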
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		iov_len = iov_length(&vq->iov[out], in);
		if (iov_len < sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
			break;
		}

		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
		payload_len = pkt->len - pkt->off;

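		/* If the packet is greater than the space available in the
		 * buffer, we split it using multiple buffers.
		 */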
		if (payload_len > iov_len - sizeof(pkt->hdr)) {
			payload_len = iov_len - sizeof(pkt->hdr);

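			/* As we copy pieces of a large packet's buffer into
			 * small rx buffers, the headers of the packets placed
			 * in the rx queue are created dynamically from the
			 * header of the current packet (except for length).
			 * For SOCK_SEQPACKET we must also clear the message
			 * delimiter bit (VIRTIO_VSOCK_SEQ_EOM) and the MSG_EOR
			 * bit (VIRTIO_VSOCK_SEQ_EOR) if set; otherwise a whole
			 * sequence of packets would carry these bits. They are
			 * restored below, once the header has been copied into
			 * the rx buffer.
			 */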
			if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
				pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
				flags_to_restore |= VIRTIO_VSOCK_SEQ_EOM;

				if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
					pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
					flags_to_restore |= VIRTIO_VSOCK_SEQ_EOR;
				}
			}
		}

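		/* Set the correct length in the header */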
		pkt->hdr.len = cpu_to_le32(payload_len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
				      &iov_iter);
		if (nbytes != payload_len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

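		/* Deliver to monitoring devices all packets that we
		 * will transmit.
		 */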
		virtio_transport_deliver_tap_pkt(pkt);

		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
		added = true;

		pkt->off += payload_len;
		total_len += payload_len;

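		/* If we didn't send all the payload we can requeue the packet
		 * to send it with the next available buffer.
		 */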
		if (pkt->off < pkt->len) {
			pkt->hdr.flags |= cpu_to_le32(flags_to_restore);

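			/* We are requeueing this packet, so clear
			 * tap_delivered: the monitoring devices should see it
			 * again when the remaining payload is sent.
			 */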
			pkt->tap_delivered = false;

			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
		} else {
			if (pkt->reply) {
				int val;

				val = atomic_dec_return(&vsock->queued_replies);

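				/* Do we have resources to resume tx
				 * processing?
				 */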
				if (val + 1 == tx_vq->num)
					restart_tx = true;
			}

			virtio_transport_free_pkt(pkt);
		}
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);

	if (restart_tx)
		vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX];

	vhost_transport_do_send_pkt(vsock, vq);
}

static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct vhost_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();

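	/* Find the vhost_vsock according to guest context id */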
	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
	if (!vsock) {
		rcu_read_unlock();
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	rcu_read_unlock();
	return len;
}

static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct vhost_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0;
	int ret = -ENODEV;
	LIST_HEAD(freeme);

	rcu_read_lock();

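	/* Find the vhost_vsock according to guest context id */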
	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
	if (!vsock)
		goto out;

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		/* If the reply quota was exhausted and freeing these replies
		 * brought us back under the limit, restart tx processing.
		 */
		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
			vhost_poll_queue(&tx_vq->poll);
	}

	ret = 0;
out:
	rcu_read_unlock();
	return ret;
}

static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	pkt->len = le32_to_cpu(pkt->hdr.len);

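	/* No payload */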
	if (!pkt->len)
		return pkt;

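	/* The pkt is too big */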
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf_len = pkt->len;

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}

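/* Is there space left for replies to rx packets? */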
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < vq->num;
}

static bool vhost_transport_seqpacket_allow(u32 remote_cid);

static struct virtio_transport vhost_transport = {
	.transport = {
		.module                   = THIS_MODULE,

		.get_local_cid            = vhost_transport_get_local_cid,

		.init                     = virtio_transport_do_socket_init,
		.destruct                 = virtio_transport_destruct,
		.release                  = virtio_transport_release,
		.connect                  = virtio_transport_connect,
		.shutdown                 = virtio_transport_shutdown,
		.cancel_pkt               = vhost_transport_cancel_pkt,

		.dgram_enqueue            = virtio_transport_dgram_enqueue,
		.dgram_dequeue            = virtio_transport_dgram_dequeue,
		.dgram_bind               = virtio_transport_dgram_bind,
		.dgram_allow              = virtio_transport_dgram_allow,

		.stream_enqueue           = virtio_transport_stream_enqueue,
		.stream_dequeue           = virtio_transport_stream_dequeue,
		.stream_has_data          = virtio_transport_stream_has_data,
		.stream_has_space         = virtio_transport_stream_has_space,
		.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
		.stream_is_active         = virtio_transport_stream_is_active,
		.stream_allow             = virtio_transport_stream_allow,

		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
		.seqpacket_allow          = vhost_transport_seqpacket_allow,
		.seqpacket_has_data       = virtio_transport_seqpacket_has_data,

		.notify_poll_in           = virtio_transport_notify_poll_in,
		.notify_poll_out          = virtio_transport_notify_poll_out,
		.notify_recv_init         = virtio_transport_notify_recv_init,
		.notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init         = virtio_transport_notify_send_init,
		.notify_send_pre_block    = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size       = virtio_transport_notify_buffer_size,
	},

	.send_pkt = vhost_transport_send_pkt,
};

static bool vhost_transport_seqpacket_allow(u32 remote_cid)
{
	struct vhost_vsock *vsock;
	bool seqpacket_allow = false;

	rcu_read_lock();
	vsock = vhost_vsock_get(remote_cid);

	if (vsock)
		seqpacket_allow = vsock->seqpacket_allow;

	rcu_read_unlock();

	return seqpacket_allow;
}

static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head, pkts = 0, total_len = 0;
	unsigned int out, in;
	bool added = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	if (!vq_meta_prefetch(vq))
		goto out;

	vhost_disable_notify(&vsock->dev, vq);
	do {
		if (!vhost_vsock_more_replies(vsock)) {
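			/* Stop tx until the device processes already
			 * pending replies.  Leave tx virtqueue
			 * callbacks disabled.
			 */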
			goto no_more_replies;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		total_len += sizeof(pkt->hdr) + pkt->len;

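		/* Deliver to monitoring devices all received packets */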
		virtio_transport_deliver_tap_pkt(pkt);

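		/* Only accept correctly addressed packets */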
		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le64_to_cpu(pkt->hdr.dst_cid) ==
		    vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(&vhost_transport, pkt);
		else
			virtio_transport_free_pkt(pkt);

		vhost_add_used(vq, head, 0);
		added = true;
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);
}

static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}

static int vhost_vsock_start(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq;
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);

		if (!vhost_vq_access_ok(vq)) {
			ret = -EFAULT;
			goto err_vq;
		}

		if (!vhost_vq_get_backend(vq)) {
			vhost_vq_set_backend(vq, vsock);
			ret = vhost_vq_init_access(vq);
			if (ret)
				goto err_vq;
		}

		mutex_unlock(&vq->mutex);
	}

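	/* Some packets may have been queued before the device was started,
	 * let's kick the send worker to send them.
	 */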
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	mutex_unlock(&vsock->dev.mutex);
	return 0;

err_vq:
	vhost_vq_set_backend(vq, NULL);
	mutex_unlock(&vq->mutex);

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}
err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static int vhost_vsock_stop(struct vhost_vsock *vsock, bool check_owner)
{
	size_t i;
	int ret = 0;

	mutex_lock(&vsock->dev.mutex);

	if (check_owner) {
		ret = vhost_dev_check_owner(&vsock->dev);
		if (ret)
			goto err;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		struct vhost_virtqueue *vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}

err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
	kvfree(vsock);
}

static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

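	/* This struct is large and allocation could fail, fall back to vmalloc
	 * if there is no other way.
	 */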
	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!vsock)
		return -ENOMEM;

	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->guest_cid = 0; /* no CID assigned yet */

	atomic_set(&vsock->queued_replies, 0);

	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
		       VHOST_VSOCK_WEIGHT, true, NULL);

	file->private_data = vsock;
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
	return 0;

out:
	vhost_vsock_free(vsock);
	return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	vhost_dev_flush(&vsock->dev);
}

static void vhost_vsock_reset_orphans(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

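	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	/* If the peer is still valid, no need to reset the connection */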
	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
		return;

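	/* If the close timeout is pending, let it expire.  This avoids races
	 * with the timeout callback.
	 */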
	if (vsk->close_work_scheduled)
		return;

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk_error_report(sk);
}

static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);
	mutex_unlock(&vhost_vsock_mutex);

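	/* Wait for other CPUs to finish using vsock */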
	synchronize_rcu();

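	/* Iterating over all connections for all CIDs to find orphans is
	 * inefficient.  Room for improvement here. */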
	vsock_for_each_connected_socket(&vhost_transport.transport,
					vhost_vsock_reset_orphans);

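	/* Don't check the owner, because we are in the release path, so we
	 * need to stop the vsock device in any case.
	 * vhost_vsock_stop() can not fail in this case, so we don't need to
	 * check the return code.
	 */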
	vhost_vsock_stop(vsock, false);
	vhost_vsock_flush(vsock);
	vhost_dev_stop(&vsock->dev);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		struct virtio_vsock_pkt *pkt;

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_dev_cleanup(&vsock->dev);
	kfree(vsock->dev.vqs);
	vhost_vsock_free(vsock);
	return 0;
}

static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
	struct vhost_vsock *other;

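	/* Refuse reserved CIDs */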
	if (guest_cid <= VMADDR_CID_HOST ||
	    guest_cid == U32_MAX)
		return -EINVAL;

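	/* 64-bit CIDs are not yet supported */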
	if (guest_cid > U32_MAX)
		return -EINVAL;

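	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
	 * VM), to make the loopback work.
	 */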
	if (vsock_find_cid(guest_cid))
		return -EADDRINUSE;

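	/* Refuse if CID is already in use */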
	mutex_lock(&vhost_vsock_mutex);
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock) {
		mutex_unlock(&vhost_vsock_mutex);
		return -EADDRINUSE;
	}

	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);

	vsock->guest_cid = guest_cid;
	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}

static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		goto err;
	}

	if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
		if (vhost_init_device_iotlb(&vsock->dev, true))
			goto err;
	}

	if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
		vsock->seqpacket_allow = true;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;

err:
	mutex_unlock(&vsock->dev.mutex);
	return -EFAULT;
}

static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 guest_cid;
	u64 features;
	int start;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_VSOCK_SET_RUNNING:
		if (copy_from_user(&start, argp, sizeof(start)))
			return -EFAULT;
		if (start)
			return vhost_vsock_start(vsock);
		else
			return vhost_vsock_stop(vsock, true);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	case VHOST_GET_BACKEND_FEATURES:
		features = VHOST_VSOCK_BACKEND_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_BACKEND_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		if (features & ~VHOST_VSOCK_BACKEND_FEATURES)
			return -EOPNOTSUPP;
		vhost_set_backend_features(&vsock->dev, features);
		return 0;
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}

static ssize_t vhost_vsock_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
	struct file *file = iocb->ki_filp;
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;
	int noblock = file->f_flags & O_NONBLOCK;

	return vhost_chr_read_iter(dev, to, noblock);
}

static ssize_t vhost_vsock_chr_write_iter(struct kiocb *iocb,
					  struct iov_iter *from)
{
	struct file *file = iocb->ki_filp;
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;

	return vhost_chr_write_iter(dev, from);
}

static __poll_t vhost_vsock_chr_poll(struct file *file, poll_table *wait)
{
	struct vhost_vsock *vsock = file->private_data;
	struct vhost_dev *dev = &vsock->dev;

	return vhost_chr_poll(file, dev, wait);
}

static const struct file_operations vhost_vsock_fops = {
	.owner          = THIS_MODULE,
	.open           = vhost_vsock_dev_open,
	.release        = vhost_vsock_dev_release,
	.llseek         = noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
	.compat_ioctl   = compat_ptr_ioctl,
	.read_iter      = vhost_vsock_chr_read_iter,
	.write_iter     = vhost_vsock_chr_write_iter,
	.poll           = vhost_vsock_chr_poll,
};

static struct miscdevice vhost_vsock_misc = {
	.minor = VHOST_VSOCK_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};

static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_register(&vhost_transport.transport,
				  VSOCK_TRANSPORT_F_H2G);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_unregister(&vhost_transport.transport);
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");