// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/inet.h>
#include <linux/kthread.h>
#include <linux/list.h>
#include <linux/radix-tree.h>
#include <linux/module.h>
#include <linux/semaphore.h>
#include <linux/wait.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_connection_sock.h>
#include <net/request_sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#define PVCALLS_VERSIONS "1"
#define MAX_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
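/*
 * MAX_RING_ORDER caps the size of the per-connection data ring a
 * frontend may request: it is advertised as "max-page-order" in
 * pvcalls_back_probe() and enforced in pvcalls_new_active_socket().
 */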

static struct pvcalls_back_global {
    struct list_head frontends;
    struct semaphore frontends_lock;
} pvcalls_back_global;

/*
 * Per-frontend data structure. It contains pointers to the command
 * ring, its event channel, a list of active sockets and a tree of
 * passive sockets.
 */
struct pvcalls_fedata {
    struct list_head list;
    struct xenbus_device *dev;
    struct xen_pvcalls_sring *sring;
    struct xen_pvcalls_back_ring ring;
    int irq;
    struct list_head socket_mappings;
    struct radix_tree_root socketpass_mappings;
    struct semaphore socket_lock;
};

struct pvcalls_ioworker {
    struct work_struct register_work;
    struct workqueue_struct *wq;
};

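/*
 * Per-active-socket data structure. "ring" points to the shared
 * indexes page; "bytes" is the mapped data ring, whose first half is
 * the "in" buffer (backend -> frontend) and whose second half is the
 * "out" buffer (frontend -> backend). The atomic counters communicate
 * with the ioworker: "read"/"write" mark pending work in each
 * direction, "io" counts outstanding wakeups, "release" aborts the
 * worker on close, and "eoi" requests a delayed irq EOI.
 */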
struct sock_mapping {
    struct list_head list;
    struct pvcalls_fedata *fedata;
    struct sockpass_mapping *sockpass;
    struct socket *sock;
    uint64_t id;
    grant_ref_t ref;
    struct pvcalls_data_intf *ring;
    void *bytes;
    struct pvcalls_data data;
    uint32_t ring_order;
    int irq;
    atomic_t read;
    atomic_t write;
    atomic_t io;
    atomic_t release;
    atomic_t eoi;
    void (*saved_data_ready)(struct sock *sk);
    struct pvcalls_ioworker ioworker;
};

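/*
 * Per-passive-socket (listening) data structure. "reqcopy" holds the
 * single accept or poll request currently in flight for this socket
 * (reqcopy.cmd == 0 means the slot is free) and is protected by
 * "copy_lock"; "register_work" runs __pvcalls_back_accept() on the
 * private workqueue "wq".
 */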
struct sockpass_mapping {
    struct list_head list;
    struct pvcalls_fedata *fedata;
    struct socket *sock;
    uint64_t id;
    struct xen_pvcalls_request reqcopy;
    spinlock_t copy_lock;
    struct workqueue_struct *wq;
    struct work_struct register_work;
    void (*saved_data_ready)(struct sock *sk);
};

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map);
static int pvcalls_back_release_active(struct xenbus_device *dev,
                       struct pvcalls_fedata *fedata,
                       struct sock_mapping *map);

static bool pvcalls_conn_back_read(void *opaque)
{
    struct sock_mapping *map = (struct sock_mapping *)opaque;
    struct msghdr msg;
    struct kvec vec[2];
    RING_IDX cons, prod, size, wanted, array_size, masked_prod, masked_cons;
    int32_t error;
    struct pvcalls_data_intf *intf = map->ring;
    struct pvcalls_data *data = &map->data;
    unsigned long flags;
    int ret;

    array_size = XEN_FLEX_RING_SIZE(map->ring_order);
    cons = intf->in_cons;
    prod = intf->in_prod;
    error = intf->in_error;
    /* read the indexes first, then deal with the data */
    virt_mb();

    if (error)
        return false;

    size = pvcalls_queued(prod, cons, array_size);
    if (size >= array_size)
        return false;
    spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
    if (skb_queue_empty(&map->sock->sk->sk_receive_queue)) {
        atomic_set(&map->read, 0);
        spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock,
                flags);
        return true;
    }
    spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);
    wanted = array_size - size;
    masked_prod = pvcalls_mask(prod, array_size);
    masked_cons = pvcalls_mask(cons, array_size);

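    /*
     * Fill the free space in the "in" ring, wrapping at the end of
     * the buffer if necessary. Illustrative example (assuming 4KiB
     * Xen pages and ring_order 0, so array_size == 2048): with
     * cons == 500 and prod == 540, 40 bytes are still queued, so
     * wanted == 2008; masked_prod (540) > masked_cons (500), hence
     * two kvecs are used: 1508 bytes at data->in + 540 up to the end
     * of the buffer, then 500 bytes from data->in up to masked_cons.
     */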
    memset(&msg, 0, sizeof(msg));
    if (masked_prod < masked_cons) {
        vec[0].iov_base = data->in + masked_prod;
        vec[0].iov_len = wanted;
        iov_iter_kvec(&msg.msg_iter, WRITE, vec, 1, wanted);
    } else {
        vec[0].iov_base = data->in + masked_prod;
        vec[0].iov_len = array_size - masked_prod;
        vec[1].iov_base = data->in;
        vec[1].iov_len = wanted - vec[0].iov_len;
        iov_iter_kvec(&msg.msg_iter, WRITE, vec, 2, wanted);
    }

    atomic_set(&map->read, 0);
    ret = inet_recvmsg(map->sock, &msg, wanted, MSG_DONTWAIT);
    WARN_ON(ret > wanted);
    if (ret == -EAGAIN) /* shouldn't happen */
        return true;
    if (!ret)
        ret = -ENOTCONN;
    spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
    if (ret > 0 && !skb_queue_empty(&map->sock->sk->sk_receive_queue))
        atomic_inc(&map->read);
    spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);

    /* write the data, then modify the indexes */
    virt_wmb();
    if (ret < 0) {
        atomic_set(&map->read, 0);
        intf->in_error = ret;
    } else
        intf->in_prod = prod + ret;
    /* update the indexes, then notify the other end */
    virt_wmb();
    notify_remote_via_irq(map->irq);

    return true;
}

static bool pvcalls_conn_back_write(struct sock_mapping *map)
{
    struct pvcalls_data_intf *intf = map->ring;
    struct pvcalls_data *data = &map->data;
    struct msghdr msg;
    struct kvec vec[2];
    RING_IDX cons, prod, size, array_size;
    int ret;

    cons = intf->out_cons;
    prod = intf->out_prod;
    /* read the indexes before dealing with the data */
    virt_mb();

    array_size = XEN_FLEX_RING_SIZE(map->ring_order);
    size = pvcalls_queued(prod, cons, array_size);
    if (size == 0)
        return false;

    memset(&msg, 0, sizeof(msg));
    msg.msg_flags |= MSG_DONTWAIT;
    if (pvcalls_mask(prod, array_size) > pvcalls_mask(cons, array_size)) {
        vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
        vec[0].iov_len = size;
        iov_iter_kvec(&msg.msg_iter, READ, vec, 1, size);
    } else {
        vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
        vec[0].iov_len = array_size - pvcalls_mask(cons, array_size);
        vec[1].iov_base = data->out;
        vec[1].iov_len = size - vec[0].iov_len;
        iov_iter_kvec(&msg.msg_iter, READ, vec, 2, size);
    }

    atomic_set(&map->write, 0);
    ret = inet_sendmsg(map->sock, &msg, size);
    if (ret == -EAGAIN) {
        atomic_inc(&map->write);
        atomic_inc(&map->io);
        return true;
    }

    /* write the data, then update the indexes */
    virt_wmb();
    if (ret < 0) {
        intf->out_error = ret;
    } else {
        intf->out_error = 0;
        intf->out_cons = cons + ret;
        prod = intf->out_prod;
    }
    /* update the indexes, then notify the other end */
    virt_wmb();
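    /*
     * Reschedule the worker if the out ring is not empty at this
     * point: either sendmsg() sent less than the "size" bytes that
     * were queued, or the frontend produced more data while we were
     * sending.
     */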
    if (prod != cons + ret) {
        atomic_inc(&map->write);
        atomic_inc(&map->io);
    }
    notify_remote_via_irq(map->irq);

    return true;
}

static void pvcalls_back_ioworker(struct work_struct *work)
{
    struct pvcalls_ioworker *ioworker = container_of(work,
        struct pvcalls_ioworker, register_work);
    struct sock_mapping *map = container_of(ioworker, struct sock_mapping,
        ioworker);
    unsigned int eoi_flags = XEN_EOI_FLAG_SPURIOUS;

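    /*
     * Every queue_work() of this worker is preceded by an
     * atomic_inc() of map->io, so the loop below drains one unit per
     * pass. "read" and "write" say which direction needs service;
     * "release" is set by pvcalls_back_release_active() to abort the
     * worker; "eoi" asks the worker to EOI the connection irq once
     * all pending writes have been handled.
     */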
    while (atomic_read(&map->io) > 0) {
        if (atomic_read(&map->release) > 0) {
            atomic_set(&map->release, 0);
            return;
        }

        if (atomic_read(&map->read) > 0 &&
            pvcalls_conn_back_read(map))
            eoi_flags = 0;
        if (atomic_read(&map->write) > 0 &&
            pvcalls_conn_back_write(map))
            eoi_flags = 0;

        if (atomic_read(&map->eoi) > 0 && !atomic_read(&map->write)) {
            atomic_set(&map->eoi, 0);
            xen_irq_lateeoi(map->irq, eoi_flags);
            eoi_flags = XEN_EOI_FLAG_SPURIOUS;
        }

        atomic_dec(&map->io);
    }
}

static int pvcalls_back_socket(struct xenbus_device *dev,
        struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    int ret;
    struct xen_pvcalls_response *rsp;

    fedata = dev_get_drvdata(&dev->dev);

    if (req->u.socket.domain != AF_INET ||
        req->u.socket.type != SOCK_STREAM ||
        (req->u.socket.protocol != IPPROTO_IP &&
         req->u.socket.protocol != AF_INET))
        ret = -EAFNOSUPPORT;
    else
        ret = 0;

    /* leave the actual socket allocation for later */

    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.socket.id = req->u.socket.id;
    rsp->ret = ret;

    return 0;
}

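/*
 * The two callbacks below are installed on the live socket by
 * pvcalls_new_active_socket() (the original sk_data_ready is saved so
 * it can be restored on release). sk_state_change kicks the frontend
 * directly so that a state transition (e.g. the peer closing the
 * connection) is noticed; sk_data_ready defers the actual data copy
 * to the ioworker.
 */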
static void pvcalls_sk_state_change(struct sock *sock)
{
    struct sock_mapping *map = sock->sk_user_data;

    if (map == NULL)
        return;

    atomic_inc(&map->read);
    notify_remote_via_irq(map->irq);
}

static void pvcalls_sk_data_ready(struct sock *sock)
{
    struct sock_mapping *map = sock->sk_user_data;
    struct pvcalls_ioworker *iow;

    if (map == NULL)
        return;

    iow = &map->ioworker;
    atomic_inc(&map->read);
    atomic_inc(&map->io);
    queue_work(iow->wq, &iow->register_work);
}

static struct sock_mapping *pvcalls_new_active_socket(
        struct pvcalls_fedata *fedata,
        uint64_t id,
        grant_ref_t ref,
        evtchn_port_t evtchn,
        struct socket *sock)
{
    int ret;
    struct sock_mapping *map;
    void *page;

    map = kzalloc(sizeof(*map), GFP_KERNEL);
    if (map == NULL)
        return NULL;

    map->fedata = fedata;
    map->sock = sock;
    map->id = id;
    map->ref = ref;

    ret = xenbus_map_ring_valloc(fedata->dev, &ref, 1, &page);
    if (ret < 0)
        goto out;
    map->ring = page;
    map->ring_order = map->ring->ring_order;
    /* first read the order, then map the data ring */
    virt_rmb();
    if (map->ring_order > MAX_RING_ORDER) {
        pr_warn("%s frontend requested ring_order %u, which is > MAX (%u)\n",
                __func__, map->ring_order, MAX_RING_ORDER);
        goto out;
    }
    ret = xenbus_map_ring_valloc(fedata->dev, map->ring->ref,
                     (1 << map->ring_order), &page);
    if (ret < 0)
        goto out;
    map->bytes = page;

    ret = bind_interdomain_evtchn_to_irqhandler_lateeoi(
            fedata->dev, evtchn,
            pvcalls_back_conn_event, 0, "pvcalls-backend", map);
    if (ret < 0)
        goto out;
    map->irq = ret;

    map->data.in = map->bytes;
    map->data.out = map->bytes + XEN_FLEX_RING_SIZE(map->ring_order);

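    /*
     * For example (assuming 4KiB Xen pages), ring_order 1 means the
     * frontend granted 1 << 1 == 2 pages for the data ring: the first
     * XEN_FLEX_RING_SIZE(1) == 4096 bytes are the "in" buffer and the
     * remaining 4096 bytes are the "out" buffer.
     */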
    map->ioworker.wq = alloc_workqueue("pvcalls_io", WQ_UNBOUND, 1);
    if (!map->ioworker.wq)
        goto out;
    atomic_set(&map->io, 1);
    INIT_WORK(&map->ioworker.register_work, pvcalls_back_ioworker);

    down(&fedata->socket_lock);
    list_add_tail(&map->list, &fedata->socket_mappings);
    up(&fedata->socket_lock);

    write_lock_bh(&map->sock->sk->sk_callback_lock);
    map->saved_data_ready = map->sock->sk->sk_data_ready;
    map->sock->sk->sk_user_data = map;
    map->sock->sk->sk_data_ready = pvcalls_sk_data_ready;
    map->sock->sk->sk_state_change = pvcalls_sk_state_change;
    write_unlock_bh(&map->sock->sk->sk_callback_lock);

    return map;
out:
    down(&fedata->socket_lock);
    list_del(&map->list);
    pvcalls_back_release_active(fedata->dev, fedata, map);
    up(&fedata->socket_lock);
    return NULL;
}

static int pvcalls_back_connect(struct xenbus_device *dev,
                struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    int ret = -EINVAL;
    struct socket *sock;
    struct sock_mapping *map;
    struct xen_pvcalls_response *rsp;
    struct sockaddr *sa = (struct sockaddr *)&req->u.connect.addr;

    fedata = dev_get_drvdata(&dev->dev);

    if (req->u.connect.len < sizeof(sa->sa_family) ||
        req->u.connect.len > sizeof(req->u.connect.addr) ||
        sa->sa_family != AF_INET)
        goto out;

    ret = sock_create(AF_INET, SOCK_STREAM, 0, &sock);
    if (ret < 0)
        goto out;
    ret = inet_stream_connect(sock, sa, req->u.connect.len, 0);
    if (ret < 0) {
        sock_release(sock);
        goto out;
    }

    map = pvcalls_new_active_socket(fedata,
                    req->u.connect.id,
                    req->u.connect.ref,
                    req->u.connect.evtchn,
                    sock);
    if (!map) {
        ret = -EFAULT;
        sock_release(sock);
    }

out:
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.connect.id = req->u.connect.id;
    rsp->ret = ret;

    return 0;
}

static int pvcalls_back_release_active(struct xenbus_device *dev,
                       struct pvcalls_fedata *fedata,
                       struct sock_mapping *map)
{
    disable_irq(map->irq);
    if (map->sock->sk != NULL) {
        write_lock_bh(&map->sock->sk->sk_callback_lock);
        map->sock->sk->sk_user_data = NULL;
        map->sock->sk->sk_data_ready = map->saved_data_ready;
        write_unlock_bh(&map->sock->sk->sk_callback_lock);
    }

    atomic_set(&map->release, 1);
    flush_work(&map->ioworker.register_work);

    xenbus_unmap_ring_vfree(dev, map->bytes);
    xenbus_unmap_ring_vfree(dev, (void *)map->ring);
    unbind_from_irqhandler(map->irq, map);

    sock_release(map->sock);
    kfree(map);

    return 0;
}

static int pvcalls_back_release_passive(struct xenbus_device *dev,
                    struct pvcalls_fedata *fedata,
                    struct sockpass_mapping *mappass)
{
    if (mappass->sock->sk != NULL) {
        write_lock_bh(&mappass->sock->sk->sk_callback_lock);
        mappass->sock->sk->sk_user_data = NULL;
        mappass->sock->sk->sk_data_ready = mappass->saved_data_ready;
        write_unlock_bh(&mappass->sock->sk->sk_callback_lock);
    }
    sock_release(mappass->sock);
    destroy_workqueue(mappass->wq);
    kfree(mappass);

    return 0;
}

static int pvcalls_back_release(struct xenbus_device *dev,
                struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    struct sock_mapping *map, *n;
    struct sockpass_mapping *mappass;
    int ret = 0;
    struct xen_pvcalls_response *rsp;

    fedata = dev_get_drvdata(&dev->dev);

    down(&fedata->socket_lock);
    list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
        if (map->id == req->u.release.id) {
            list_del(&map->list);
            up(&fedata->socket_lock);
            ret = pvcalls_back_release_active(dev, fedata, map);
            goto out;
        }
    }
    mappass = radix_tree_lookup(&fedata->socketpass_mappings,
                    req->u.release.id);
    if (mappass != NULL) {
        radix_tree_delete(&fedata->socketpass_mappings, mappass->id);
        up(&fedata->socket_lock);
        ret = pvcalls_back_release_passive(dev, fedata, mappass);
    } else
        up(&fedata->socket_lock);

out:
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->u.release.id = req->u.release.id;
    rsp->cmd = req->cmd;
    rsp->ret = ret;
    return 0;
}

static void __pvcalls_back_accept(struct work_struct *work)
{
    struct sockpass_mapping *mappass = container_of(
        work, struct sockpass_mapping, register_work);
    struct sock_mapping *map;
    struct pvcalls_ioworker *iow;
    struct pvcalls_fedata *fedata;
    struct socket *sock;
    struct xen_pvcalls_response *rsp;
    struct xen_pvcalls_request *req;
    int notify;
    int ret = -EINVAL;
    unsigned long flags;

    fedata = mappass->fedata;
    /*
     * __pvcalls_back_accept can race against pvcalls_back_accept.
     * We only need to check the value of "cmd" on read. It could be
     * done atomically, but to simplify the code on the write side, we
     * use a spinlock.
     */
    spin_lock_irqsave(&mappass->copy_lock, flags);
    req = &mappass->reqcopy;
    if (req->cmd != PVCALLS_ACCEPT) {
        spin_unlock_irqrestore(&mappass->copy_lock, flags);
        return;
    }
    spin_unlock_irqrestore(&mappass->copy_lock, flags);

    sock = sock_alloc();
    if (sock == NULL)
        goto out_error;
    sock->type = mappass->sock->type;
    sock->ops = mappass->sock->ops;

    ret = inet_accept(mappass->sock, sock, O_NONBLOCK, true);
    if (ret == -EAGAIN) {
        sock_release(sock);
        return;
    }

    map = pvcalls_new_active_socket(fedata,
                    req->u.accept.id_new,
                    req->u.accept.ref,
                    req->u.accept.evtchn,
                    sock);
    if (!map) {
        ret = -EFAULT;
        sock_release(sock);
        goto out_error;
    }

    map->sockpass = mappass;
    iow = &map->ioworker;
    atomic_inc(&map->read);
    atomic_inc(&map->io);
    queue_work(iow->wq, &iow->register_work);

out_error:
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.accept.id = req->u.accept.id;
    rsp->ret = ret;
    RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
    if (notify)
        notify_remote_via_irq(fedata->irq);

    mappass->reqcopy.cmd = 0;
}

static void pvcalls_pass_sk_data_ready(struct sock *sock)
{
    struct sockpass_mapping *mappass = sock->sk_user_data;
    struct pvcalls_fedata *fedata;
    struct xen_pvcalls_response *rsp;
    unsigned long flags;
    int notify;

    if (mappass == NULL)
        return;

    fedata = mappass->fedata;
    spin_lock_irqsave(&mappass->copy_lock, flags);
    if (mappass->reqcopy.cmd == PVCALLS_POLL) {
        rsp = RING_GET_RESPONSE(&fedata->ring,
                    fedata->ring.rsp_prod_pvt++);
        rsp->req_id = mappass->reqcopy.req_id;
        rsp->u.poll.id = mappass->reqcopy.u.poll.id;
        rsp->cmd = mappass->reqcopy.cmd;
        rsp->ret = 0;

        mappass->reqcopy.cmd = 0;
        spin_unlock_irqrestore(&mappass->copy_lock, flags);

        RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
        if (notify)
            notify_remote_via_irq(mappass->fedata->irq);
    } else {
        spin_unlock_irqrestore(&mappass->copy_lock, flags);
        queue_work(mappass->wq, &mappass->register_work);
    }
}

static int pvcalls_back_bind(struct xenbus_device *dev,
                 struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    int ret;
    struct sockpass_mapping *map;
    struct xen_pvcalls_response *rsp;

    fedata = dev_get_drvdata(&dev->dev);

    map = kzalloc(sizeof(*map), GFP_KERNEL);
    if (map == NULL) {
        ret = -ENOMEM;
        goto out;
    }

    INIT_WORK(&map->register_work, __pvcalls_back_accept);
    spin_lock_init(&map->copy_lock);
    map->wq = alloc_workqueue("pvcalls_wq", WQ_UNBOUND, 1);
    if (!map->wq) {
        ret = -ENOMEM;
        goto out;
    }

    ret = sock_create(AF_INET, SOCK_STREAM, 0, &map->sock);
    if (ret < 0)
        goto out;

    ret = inet_bind(map->sock, (struct sockaddr *)&req->u.bind.addr,
            req->u.bind.len);
    if (ret < 0)
        goto out;

    map->fedata = fedata;
    map->id = req->u.bind.id;

    down(&fedata->socket_lock);
    ret = radix_tree_insert(&fedata->socketpass_mappings, map->id,
                map);
    up(&fedata->socket_lock);
    if (ret)
        goto out;

    write_lock_bh(&map->sock->sk->sk_callback_lock);
    map->saved_data_ready = map->sock->sk->sk_data_ready;
    map->sock->sk->sk_user_data = map;
    map->sock->sk->sk_data_ready = pvcalls_pass_sk_data_ready;
    write_unlock_bh(&map->sock->sk->sk_callback_lock);

out:
    if (ret) {
        if (map && map->sock)
            sock_release(map->sock);
        if (map && map->wq)
            destroy_workqueue(map->wq);
        kfree(map);
    }
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.bind.id = req->u.bind.id;
    rsp->ret = ret;
    return 0;
}

static int pvcalls_back_listen(struct xenbus_device *dev,
                   struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    int ret = -EINVAL;
    struct sockpass_mapping *map;
    struct xen_pvcalls_response *rsp;

    fedata = dev_get_drvdata(&dev->dev);

    down(&fedata->socket_lock);
    map = radix_tree_lookup(&fedata->socketpass_mappings, req->u.listen.id);
    up(&fedata->socket_lock);
    if (map == NULL)
        goto out;

    ret = inet_listen(map->sock, req->u.listen.backlog);

out:
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.listen.id = req->u.listen.id;
    rsp->ret = ret;
    return 0;
}

static int pvcalls_back_accept(struct xenbus_device *dev,
                   struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    struct sockpass_mapping *mappass;
    int ret = -EINVAL;
    struct xen_pvcalls_response *rsp;
    unsigned long flags;

    fedata = dev_get_drvdata(&dev->dev);

    down(&fedata->socket_lock);
    mappass = radix_tree_lookup(&fedata->socketpass_mappings,
        req->u.accept.id);
    up(&fedata->socket_lock);
    if (mappass == NULL)
        goto out_error;

    /*
     * Limitation of the current implementation: only support one
     * concurrent accept or poll call on one socket.
     */
    spin_lock_irqsave(&mappass->copy_lock, flags);
    if (mappass->reqcopy.cmd != 0) {
        spin_unlock_irqrestore(&mappass->copy_lock, flags);
        ret = -EINTR;
        goto out_error;
    }

    mappass->reqcopy = *req;
    spin_unlock_irqrestore(&mappass->copy_lock, flags);
    queue_work(mappass->wq, &mappass->register_work);

    /* Tell the caller we don't need to send back a notification yet */
    return -1;

out_error:
    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.accept.id = req->u.accept.id;
    rsp->ret = ret;
    return 0;
}

static int pvcalls_back_poll(struct xenbus_device *dev,
                 struct xen_pvcalls_request *req)
{
    struct pvcalls_fedata *fedata;
    struct sockpass_mapping *mappass;
    struct xen_pvcalls_response *rsp;
    struct inet_connection_sock *icsk;
    struct request_sock_queue *queue;
    unsigned long flags;
    int ret;
    bool data;

    fedata = dev_get_drvdata(&dev->dev);

    down(&fedata->socket_lock);
    mappass = radix_tree_lookup(&fedata->socketpass_mappings,
                    req->u.poll.id);
    up(&fedata->socket_lock);
    if (mappass == NULL)
        return -EINVAL;

    /*
     * Limitation of the current implementation: only support one
     * concurrent accept or poll call on one socket.
     */
    spin_lock_irqsave(&mappass->copy_lock, flags);
    if (mappass->reqcopy.cmd != 0) {
        ret = -EINTR;
        goto out;
    }

    mappass->reqcopy = *req;
    icsk = inet_csk(mappass->sock->sk);
    queue = &icsk->icsk_accept_queue;
    data = READ_ONCE(queue->rskq_accept_head) != NULL;
    if (data) {
        mappass->reqcopy.cmd = 0;
        ret = 0;
        goto out;
    }
    spin_unlock_irqrestore(&mappass->copy_lock, flags);

    /* Tell the caller we don't need to send back a notification yet */
    return -1;

out:
    spin_unlock_irqrestore(&mappass->copy_lock, flags);

    rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
    rsp->req_id = req->req_id;
    rsp->cmd = req->cmd;
    rsp->u.poll.id = req->u.poll.id;
    rsp->ret = ret;
    return 0;
}

static int pvcalls_back_handle_cmd(struct xenbus_device *dev,
                   struct xen_pvcalls_request *req)
{
    int ret = 0;

    switch (req->cmd) {
    case PVCALLS_SOCKET:
        ret = pvcalls_back_socket(dev, req);
        break;
    case PVCALLS_CONNECT:
        ret = pvcalls_back_connect(dev, req);
        break;
    case PVCALLS_RELEASE:
        ret = pvcalls_back_release(dev, req);
        break;
    case PVCALLS_BIND:
        ret = pvcalls_back_bind(dev, req);
        break;
    case PVCALLS_LISTEN:
        ret = pvcalls_back_listen(dev, req);
        break;
    case PVCALLS_ACCEPT:
        ret = pvcalls_back_accept(dev, req);
        break;
    case PVCALLS_POLL:
        ret = pvcalls_back_poll(dev, req);
        break;
    default:
    {
        struct pvcalls_fedata *fedata;
        struct xen_pvcalls_response *rsp;

        fedata = dev_get_drvdata(&dev->dev);
        rsp = RING_GET_RESPONSE(
                &fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->ret = -ENOTSUPP;
        break;
    }
    }
    return ret;
}

static void pvcalls_back_work(struct pvcalls_fedata *fedata)
{
    int notify, notify_all = 0, more = 1;
    struct xen_pvcalls_request req;
    struct xenbus_device *dev = fedata->dev;

    while (more) {
        while (RING_HAS_UNCONSUMED_REQUESTS(&fedata->ring)) {
            RING_COPY_REQUEST(&fedata->ring,
                      fedata->ring.req_cons++,
                      &req);

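            /*
             * A return value of zero means the handler has already
             * placed a response on the ring, so push it and batch
             * the notification. Accept and poll return -1 instead:
             * their response is generated later, from
             * __pvcalls_back_accept() or pvcalls_pass_sk_data_ready().
             */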
            if (!pvcalls_back_handle_cmd(dev, &req)) {
                RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(
                    &fedata->ring, notify);
                notify_all += notify;
            }
        }

        if (notify_all) {
            notify_remote_via_irq(fedata->irq);
            notify_all = 0;
        }

        RING_FINAL_CHECK_FOR_REQUESTS(&fedata->ring, more);
    }
}

static irqreturn_t pvcalls_back_event(int irq, void *dev_id)
{
    struct xenbus_device *dev = dev_id;
    struct pvcalls_fedata *fedata = NULL;
    unsigned int eoi_flags = XEN_EOI_FLAG_SPURIOUS;

    if (dev) {
        fedata = dev_get_drvdata(&dev->dev);
        if (fedata) {
            pvcalls_back_work(fedata);
            eoi_flags = 0;
        }
    }

    xen_irq_lateeoi(irq, eoi_flags);

    return IRQ_HANDLED;
}

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map)
{
    struct sock_mapping *map = sock_map;
    struct pvcalls_ioworker *iow;

    if (map == NULL || map->sock == NULL || map->sock->sk == NULL ||
        map->sock->sk->sk_user_data != map) {
        xen_irq_lateeoi(irq, 0);
        return IRQ_HANDLED;
    }

    iow = &map->ioworker;

    atomic_inc(&map->write);
    atomic_inc(&map->eoi);
    atomic_inc(&map->io);
    queue_work(iow->wq, &iow->register_work);

    return IRQ_HANDLED;
}

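/*
 * Connect to a frontend. The command ring reference and event channel
 * are read from the frontend's xenstore directory (dev->otherend);
 * with illustrative values, the frontend is expected to have written
 * something like "ring-ref" = "8" and "port" = "5" under its own node
 * before switching to Connected.
 */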
static int backend_connect(struct xenbus_device *dev)
{
    int err;
    evtchn_port_t evtchn;
    grant_ref_t ring_ref;
    struct pvcalls_fedata *fedata = NULL;

    fedata = kzalloc(sizeof(struct pvcalls_fedata), GFP_KERNEL);
    if (!fedata)
        return -ENOMEM;

    fedata->irq = -1;
    err = xenbus_scanf(XBT_NIL, dev->otherend, "port", "%u",
               &evtchn);
    if (err != 1) {
        err = -EINVAL;
        xenbus_dev_fatal(dev, err, "reading %s/event-channel",
                 dev->otherend);
        goto error;
    }

    err = xenbus_scanf(XBT_NIL, dev->otherend, "ring-ref", "%u", &ring_ref);
    if (err != 1) {
        err = -EINVAL;
        xenbus_dev_fatal(dev, err, "reading %s/ring-ref",
                 dev->otherend);
        goto error;
    }

    err = bind_interdomain_evtchn_to_irq_lateeoi(dev, evtchn);
    if (err < 0)
        goto error;
    fedata->irq = err;

    err = request_threaded_irq(fedata->irq, NULL, pvcalls_back_event,
                   IRQF_ONESHOT, "pvcalls-back", dev);
    if (err < 0)
        goto error;

    err = xenbus_map_ring_valloc(dev, &ring_ref, 1,
                     (void **)&fedata->sring);
    if (err < 0)
        goto error;

    BACK_RING_INIT(&fedata->ring, fedata->sring, XEN_PAGE_SIZE * 1);
    fedata->dev = dev;

    INIT_LIST_HEAD(&fedata->socket_mappings);
    INIT_RADIX_TREE(&fedata->socketpass_mappings, GFP_KERNEL);
    sema_init(&fedata->socket_lock, 1);
    dev_set_drvdata(&dev->dev, fedata);

    down(&pvcalls_back_global.frontends_lock);
    list_add_tail(&fedata->list, &pvcalls_back_global.frontends);
    up(&pvcalls_back_global.frontends_lock);

    return 0;

 error:
    if (fedata->irq >= 0)
        unbind_from_irqhandler(fedata->irq, dev);
    if (fedata->sring != NULL)
        xenbus_unmap_ring_vfree(dev, fedata->sring);
    kfree(fedata);
    return err;
}

static int backend_disconnect(struct xenbus_device *dev)
{
    struct pvcalls_fedata *fedata;
    struct sock_mapping *map, *n;
    struct sockpass_mapping *mappass;
    struct radix_tree_iter iter;
    void **slot;


    fedata = dev_get_drvdata(&dev->dev);

    down(&fedata->socket_lock);
    list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
        list_del(&map->list);
        pvcalls_back_release_active(dev, fedata, map);
    }

    radix_tree_for_each_slot(slot, &fedata->socketpass_mappings, &iter, 0) {
        mappass = radix_tree_deref_slot(slot);
        if (!mappass)
            continue;
        if (radix_tree_exception(mappass)) {
            if (radix_tree_deref_retry(mappass))
                slot = radix_tree_iter_retry(&iter);
        } else {
            radix_tree_delete(&fedata->socketpass_mappings,
                      mappass->id);
            pvcalls_back_release_passive(dev, fedata, mappass);
        }
    }
    up(&fedata->socket_lock);

    unbind_from_irqhandler(fedata->irq, dev);
    xenbus_unmap_ring_vfree(dev, fedata->sring);

    list_del(&fedata->list);
    kfree(fedata);
    dev_set_drvdata(&dev->dev, NULL);

    return 0;
}

static int pvcalls_back_probe(struct xenbus_device *dev,
                  const struct xenbus_device_id *id)
{
    int err, abort;
    struct xenbus_transaction xbt;

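    /*
     * Advertise the backend's capabilities in a single xenstore
     * transaction. xenbus_transaction_end() returns -EAGAIN when the
     * transaction raced with another xenstore update, in which case
     * the whole transaction is retried from "again".
     */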
again:
    abort = 1;

    err = xenbus_transaction_start(&xbt);
    if (err) {
        pr_warn("%s cannot create xenstore transaction\n", __func__);
        return err;
    }

    err = xenbus_printf(xbt, dev->nodename, "versions", "%s",
                PVCALLS_VERSIONS);
    if (err) {
        pr_warn("%s write out 'versions' failed\n", __func__);
        goto abort;
    }

    err = xenbus_printf(xbt, dev->nodename, "max-page-order", "%u",
                MAX_RING_ORDER);
    if (err) {
        pr_warn("%s write out 'max-page-order' failed\n", __func__);
        goto abort;
    }

    err = xenbus_printf(xbt, dev->nodename, "function-calls",
                XENBUS_FUNCTIONS_CALLS);
    if (err) {
        pr_warn("%s write out 'function-calls' failed\n", __func__);
        goto abort;
    }

    abort = 0;
abort:
    err = xenbus_transaction_end(xbt, abort);
    if (err) {
        if (err == -EAGAIN && !abort)
            goto again;
        pr_warn("%s cannot complete xenstore transaction\n", __func__);
        return err;
    }

    if (abort)
        return -EFAULT;

    xenbus_switch_state(dev, XenbusStateInitWait);

    return 0;
}

static void set_backend_state(struct xenbus_device *dev,
                  enum xenbus_state state)
{
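    /*
     * Walk the xenbus state machine one legal transition at a time
     * until the target state is reached. For example, going from
     * Connected to Closed first switches to Closing (disconnecting
     * the frontend) and then, on the next loop iteration, to Closed.
     */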
    while (dev->state != state) {
        switch (dev->state) {
        case XenbusStateClosed:
            switch (state) {
            case XenbusStateInitWait:
            case XenbusStateConnected:
                xenbus_switch_state(dev, XenbusStateInitWait);
                break;
            case XenbusStateClosing:
                xenbus_switch_state(dev, XenbusStateClosing);
                break;
            default:
                WARN_ON(1);
            }
            break;
        case XenbusStateInitWait:
        case XenbusStateInitialised:
            switch (state) {
            case XenbusStateConnected:
                if (backend_connect(dev))
                    return;
                xenbus_switch_state(dev, XenbusStateConnected);
                break;
            case XenbusStateClosing:
            case XenbusStateClosed:
                xenbus_switch_state(dev, XenbusStateClosing);
                break;
            default:
                WARN_ON(1);
            }
            break;
        case XenbusStateConnected:
            switch (state) {
            case XenbusStateInitWait:
            case XenbusStateClosing:
            case XenbusStateClosed:
                down(&pvcalls_back_global.frontends_lock);
                backend_disconnect(dev);
                up(&pvcalls_back_global.frontends_lock);
                xenbus_switch_state(dev, XenbusStateClosing);
                break;
            default:
                WARN_ON(1);
            }
            break;
        case XenbusStateClosing:
            switch (state) {
            case XenbusStateInitWait:
            case XenbusStateConnected:
            case XenbusStateClosed:
                xenbus_switch_state(dev, XenbusStateClosed);
                break;
            default:
                WARN_ON(1);
            }
            break;
        default:
            WARN_ON(1);
        }
    }
}

static void pvcalls_back_changed(struct xenbus_device *dev,
                 enum xenbus_state frontend_state)
{
    switch (frontend_state) {
    case XenbusStateInitialising:
        set_backend_state(dev, XenbusStateInitWait);
        break;

    case XenbusStateInitialised:
    case XenbusStateConnected:
        set_backend_state(dev, XenbusStateConnected);
        break;

    case XenbusStateClosing:
        set_backend_state(dev, XenbusStateClosing);
        break;

    case XenbusStateClosed:
        set_backend_state(dev, XenbusStateClosed);
        if (xenbus_dev_is_online(dev))
            break;
        device_unregister(&dev->dev);
        break;
    case XenbusStateUnknown:
        set_backend_state(dev, XenbusStateClosed);
        device_unregister(&dev->dev);
        break;

    default:
        xenbus_dev_fatal(dev, -EINVAL, "saw state %d at frontend",
                 frontend_state);
        break;
    }
}

static int pvcalls_back_remove(struct xenbus_device *dev)
{
    return 0;
}

static int pvcalls_back_uevent(struct xenbus_device *xdev,
                   struct kobj_uevent_env *env)
{
    return 0;
}

static const struct xenbus_device_id pvcalls_back_ids[] = {
    { "pvcalls" },
    { "" }
};

static struct xenbus_driver pvcalls_back_driver = {
    .ids = pvcalls_back_ids,
    .probe = pvcalls_back_probe,
    .remove = pvcalls_back_remove,
    .uevent = pvcalls_back_uevent,
    .otherend_changed = pvcalls_back_changed,
};

static int __init pvcalls_back_init(void)
{
    int ret;

    if (!xen_domain())
        return -ENODEV;

    ret = xenbus_register_backend(&pvcalls_back_driver);
    if (ret < 0)
        return ret;

    sema_init(&pvcalls_back_global.frontends_lock, 1);
    INIT_LIST_HEAD(&pvcalls_back_global.frontends);
    return 0;
}
module_init(pvcalls_back_init);

static void __exit pvcalls_back_fin(void)
{
    struct pvcalls_fedata *fedata, *nfedata;

    down(&pvcalls_back_global.frontends_lock);
    list_for_each_entry_safe(fedata, nfedata,
                 &pvcalls_back_global.frontends, list) {
        backend_disconnect(fedata->dev);
    }
    up(&pvcalls_back_global.frontends_lock);

    xenbus_unregister_driver(&pvcalls_back_driver);
}

module_exit(pvcalls_back_fin);

MODULE_DESCRIPTION("Xen PV Calls backend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");