0001 // SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
0002 /*
0003  * Copyright(c) 2016 Intel Corporation.
0004  */
0005 
0006 #include <linux/slab.h>
0007 #include <linux/vmalloc.h>
0008 #include <rdma/ib_umem.h>
0009 #include <rdma/rdma_vt.h>
0010 #include "vt.h"
0011 #include "mr.h"
0012 #include "trace.h"
0013 
0014 /**
0015  * rvt_driver_mr_init - Init MR resources per driver
0016  * @rdi: rvt dev struct
0017  *
0018  * Do any initialization needed when a driver registers with rdmavt.
0019  *
0020  * Return: 0 on success or errno on failure
0021  */
0022 int rvt_driver_mr_init(struct rvt_dev_info *rdi)
0023 {
0024     unsigned int lkey_table_size = rdi->dparms.lkey_table_size;
0025     unsigned lk_tab_size;
0026     int i;
0027 
0028     /*
0029      * The top lkey_table_size bits are used to index the
0030      * table.  The lower 8 bits can be owned by the user (copied from
0031      * the LKEY).  The remaining bits act as a generation number or tag.
0032      */
0033     if (!lkey_table_size)
0034         return -EINVAL;
0035 
0036     spin_lock_init(&rdi->lkey_table.lock);
0037 
0038     /* ensure generation is at least 4 bits */
0039     if (lkey_table_size > RVT_MAX_LKEY_TABLE_BITS) {
0040         rvt_pr_warn(rdi, "lkey bits %u too large, reduced to %u\n",
0041                 lkey_table_size, RVT_MAX_LKEY_TABLE_BITS);
0042         rdi->dparms.lkey_table_size = RVT_MAX_LKEY_TABLE_BITS;
0043         lkey_table_size = rdi->dparms.lkey_table_size;
0044     }
0045     rdi->lkey_table.max = 1 << lkey_table_size;
0046     rdi->lkey_table.shift = 32 - lkey_table_size;
0047     lk_tab_size = rdi->lkey_table.max * sizeof(*rdi->lkey_table.table);
0048     rdi->lkey_table.table = (struct rvt_mregion __rcu **)
0049                    vmalloc_node(lk_tab_size, rdi->dparms.node);
0050     if (!rdi->lkey_table.table)
0051         return -ENOMEM;
0052 
0053     RCU_INIT_POINTER(rdi->dma_mr, NULL);
0054     for (i = 0; i < rdi->lkey_table.max; i++)
0055         RCU_INIT_POINTER(rdi->lkey_table.table[i], NULL);
0056 
0057     rdi->dparms.props.max_mr = rdi->lkey_table.max;
0058     return 0;
0059 }
0060 
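/*
 * Illustrative sketch, not part of the rdmavt source: how a registering
 * driver might size the lkey table before rvt_driver_mr_init() runs, and
 * what the derived fields end up holding.  The helper name and the value
 * 16 are hypothetical; the rvt_dev_info fields are the ones used above.
 */
static inline void example_lkey_table_setup(struct rvt_dev_info *rdi)
{
    rdi->dparms.lkey_table_size = 16;   /* 2^16 = 65536 table slots */

    /*
     * After rvt_driver_mr_init():
     *   rdi->lkey_table.max   == 1 << 16 == 65536
     *   rdi->lkey_table.shift == 32 - 16 == 16
     * so a key is mapped to its slot with: slot = key >> lkey_table.shift
     */
}
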
0061 /**
0062  * rvt_mr_exit - clean up MR
0063  * @rdi: rvt dev structure
0064  *
0065  * called when drivers have unregistered or perhaps failed to register with us
0066  */
0067 void rvt_mr_exit(struct rvt_dev_info *rdi)
0068 {
0069     if (rdi->dma_mr)
0070         rvt_pr_err(rdi, "DMA MR not null!\n");
0071 
0072     vfree(rdi->lkey_table.table);
0073 }
0074 
0075 static void rvt_deinit_mregion(struct rvt_mregion *mr)
0076 {
0077     int i = mr->mapsz;
0078 
0079     mr->mapsz = 0;
0080     while (i)
0081         kfree(mr->map[--i]);
0082     percpu_ref_exit(&mr->refcount);
0083 }
0084 
0085 static void __rvt_mregion_complete(struct percpu_ref *ref)
0086 {
0087     struct rvt_mregion *mr = container_of(ref, struct rvt_mregion,
0088                           refcount);
0089 
0090     complete(&mr->comp);
0091 }
0092 
0093 static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd,
0094                 int count, unsigned int percpu_flags)
0095 {
0096     int m, i = 0;
0097     struct rvt_dev_info *dev = ib_to_rvt(pd->device);
0098 
0099     mr->mapsz = 0;
0100     m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
0101     for (; i < m; i++) {
0102         mr->map[i] = kzalloc_node(sizeof(*mr->map[0]), GFP_KERNEL,
0103                       dev->dparms.node);
0104         if (!mr->map[i])
0105             goto bail;
0106         mr->mapsz++;
0107     }
0108     init_completion(&mr->comp);
0109     /* the initial reference counts the pointer returned to the caller */
0110     if (percpu_ref_init(&mr->refcount, &__rvt_mregion_complete,
0111                 percpu_flags, GFP_KERNEL))
0112         goto bail;
0113 
0114     atomic_set(&mr->lkey_invalid, 0);
0115     mr->pd = pd;
0116     mr->max_segs = count;
0117     return 0;
0118 bail:
0119     rvt_deinit_mregion(mr);
0120     return -ENOMEM;
0121 }
0122 
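/*
 * Reference-lifetime sketch for orientation only; the helper below is
 * hypothetical and not part of rdmavt.  Users of an MR take references
 * with rvt_get_mr()/rvt_put_mr() (percpu_ref get/put).  Deregistration
 * kills the percpu_ref; once the last reference is dropped,
 * __rvt_mregion_complete() fires and completes mr->comp, which
 * rvt_check_refs() below waits on with a timeout.
 */
static inline void example_mregion_lifetime(struct rvt_mregion *mr)
{
    rvt_get_mr(mr);                 /* hold the MR while it is in use */
    /* ... data transfer referencing the MR ... */
    rvt_put_mr(mr);                 /* drop the reference when done */

    percpu_ref_kill(&mr->refcount); /* done by rvt_free_lkey() */
    wait_for_completion(&mr->comp); /* done by rvt_check_refs(), with timeout */
}
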
0123 /**
0124  * rvt_alloc_lkey - allocate an lkey
0125  * @mr: memory region that this lkey protects
0126  * @dma_region: 0->normal key, 1->restricted DMA key
0127  *
0128  * Returns 0 if successful, otherwise returns -errno.
0129  *
0130  * Increments mr reference count as required.
0131  *
0132  * Sets the lkey field of mr for non-dma regions.
0133  *
0134  */
0135 static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region)
0136 {
0137     unsigned long flags;
0138     u32 r;
0139     u32 n;
0140     int ret = 0;
0141     struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
0142     struct rvt_lkey_table *rkt = &dev->lkey_table;
0143 
0144     rvt_get_mr(mr);
0145     spin_lock_irqsave(&rkt->lock, flags);
0146 
0147     /* special case for dma_mr lkey == 0 */
0148     if (dma_region) {
0149         struct rvt_mregion *tmr;
0150 
0151         tmr = rcu_access_pointer(dev->dma_mr);
0152         if (!tmr) {
0153             mr->lkey_published = 1;
0154             /* Ensure published is written first */
0155             rcu_assign_pointer(dev->dma_mr, mr);
0156             rvt_get_mr(mr);
0157         }
0158         goto success;
0159     }
0160 
0161     /* Find the next available LKEY */
0162     r = rkt->next;
0163     n = r;
0164     for (;;) {
0165         if (!rcu_access_pointer(rkt->table[r]))
0166             break;
0167         r = (r + 1) & (rkt->max - 1);
0168         if (r == n)
0169             goto bail;
0170     }
0171     rkt->next = (r + 1) & (rkt->max - 1);
0172     /*
0173      * Make sure lkey is never zero which is reserved to indicate an
0174      * unrestricted LKEY.
0175      */
0176     rkt->gen++;
0177     /*
0178      * bits are capped to ensure enough bits for generation number
0179      */
0180     mr->lkey = (r << (32 - dev->dparms.lkey_table_size)) |
0181         ((((1 << (24 - dev->dparms.lkey_table_size)) - 1) & rkt->gen)
0182          << 8);
0183     if (mr->lkey == 0) {
0184         mr->lkey |= 1 << 8;
0185         rkt->gen++;
0186     }
0187     mr->lkey_published = 1;
0188     /* Ensure published is written first */
0189     rcu_assign_pointer(rkt->table[r], mr);
0190 success:
0191     spin_unlock_irqrestore(&rkt->lock, flags);
0192 out:
0193     return ret;
0194 bail:
0195     rvt_put_mr(mr);
0196     spin_unlock_irqrestore(&rkt->lock, flags);
0197     ret = -ENOMEM;
0198     goto out;
0199 }
0200 
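/*
 * Worked example of the lkey layout built above; the numbers are
 * illustrative only.  With dparms.lkey_table_size == 16, table slot
 * r == 0x0005 and rkt->gen == 0x2a:
 *
 *   mr->lkey = (0x0005 << (32 - 16))
 *            | ((((1 << (24 - 16)) - 1) & 0x2a) << 8)
 *            = 0x00050000 | 0x00002a00
 *            = 0x00052a00
 *
 * i.e. the table index sits in the top 16 bits, an 8-bit generation tag
 * in bits 8..15, and the low 8 bits are left for the consumer (see the
 * rvt_fast_reg_mr() key check below).
 */
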
0201 /**
0202  * rvt_free_lkey - free an lkey
0203  * @mr: mr to free from tables
0204  */
0205 static void rvt_free_lkey(struct rvt_mregion *mr)
0206 {
0207     unsigned long flags;
0208     u32 lkey = mr->lkey;
0209     u32 r;
0210     struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
0211     struct rvt_lkey_table *rkt = &dev->lkey_table;
0212     int freed = 0;
0213 
0214     spin_lock_irqsave(&rkt->lock, flags);
0215     if (!lkey) {
0216         if (mr->lkey_published) {
0217             mr->lkey_published = 0;
0218             /* ensure published is written before pointer */
0219             rcu_assign_pointer(dev->dma_mr, NULL);
0220             rvt_put_mr(mr);
0221         }
0222     } else {
0223         if (!mr->lkey_published)
0224             goto out;
0225         r = lkey >> (32 - dev->dparms.lkey_table_size);
0226         mr->lkey_published = 0;
0227         /* ensure published is written before pointer */
0228         rcu_assign_pointer(rkt->table[r], NULL);
0229     }
0230     freed++;
0231 out:
0232     spin_unlock_irqrestore(&rkt->lock, flags);
0233     if (freed)
0234         percpu_ref_kill(&mr->refcount);
0235 }
0236 
0237 static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd)
0238 {
0239     struct rvt_mr *mr;
0240     int rval = -ENOMEM;
0241     int m;
0242 
0243     /* Allocate struct plus pointers to first level page tables. */
0244     m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
0245     mr = kzalloc(struct_size(mr, mr.map, m), GFP_KERNEL);
0246     if (!mr)
0247         goto bail;
0248 
0249     rval = rvt_init_mregion(&mr->mr, pd, count, 0);
0250     if (rval)
0251         goto bail;
0252     /*
0253      * The IB core will initialize mr->ibmr except for
0254      * lkey and rkey.
0255      */
0256     rval = rvt_alloc_lkey(&mr->mr, 0);
0257     if (rval)
0258         goto bail_mregion;
0259     mr->ibmr.lkey = mr->mr.lkey;
0260     mr->ibmr.rkey = mr->mr.lkey;
0261 done:
0262     return mr;
0263 
0264 bail_mregion:
0265     rvt_deinit_mregion(&mr->mr);
0266 bail:
0267     kfree(mr);
0268     mr = ERR_PTR(rval);
0269     goto done;
0270 }
0271 
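/*
 * Layout note with a hypothetical count, for orientation: the page list
 * behind an MR is two-level.  mr->mr.map[] holds the first-level chunk
 * pointers, each chunk carrying RVT_SEGSZ (vaddr, length) segments, so
 * segment s lives at map[s / RVT_SEGSZ]->segs[s % RVT_SEGSZ].  For a
 * request of count pages, m == (count + RVT_SEGSZ - 1) / RVT_SEGSZ
 * chunks are allocated by rvt_init_mregion() above.
 */
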
0272 static void __rvt_free_mr(struct rvt_mr *mr)
0273 {
0274     rvt_free_lkey(&mr->mr);
0275     rvt_deinit_mregion(&mr->mr);
0276     kfree(mr);
0277 }
0278 
0279 /**
0280  * rvt_get_dma_mr - get a DMA memory region
0281  * @pd: protection domain for this memory region
0282  * @acc: access flags
0283  *
0284  * Return: the memory region on success, otherwise returns an errno.
0285  */
0286 struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc)
0287 {
0288     struct rvt_mr *mr;
0289     struct ib_mr *ret;
0290     int rval;
0291 
0292     if (ibpd_to_rvtpd(pd)->user)
0293         return ERR_PTR(-EPERM);
0294 
0295     mr = kzalloc(sizeof(*mr), GFP_KERNEL);
0296     if (!mr) {
0297         ret = ERR_PTR(-ENOMEM);
0298         goto bail;
0299     }
0300 
0301     rval = rvt_init_mregion(&mr->mr, pd, 0, 0);
0302     if (rval) {
0303         ret = ERR_PTR(rval);
0304         goto bail;
0305     }
0306 
0307     rval = rvt_alloc_lkey(&mr->mr, 1);
0308     if (rval) {
0309         ret = ERR_PTR(rval);
0310         goto bail_mregion;
0311     }
0312 
0313     mr->mr.access_flags = acc;
0314     ret = &mr->ibmr;
0315 done:
0316     return ret;
0317 
0318 bail_mregion:
0319     rvt_deinit_mregion(&mr->mr);
0320 bail:
0321     kfree(mr);
0322     goto done;
0323 }
0324 
0325 /**
0326  * rvt_reg_user_mr - register a userspace memory region
0327  * @pd: protection domain for this memory region
0328  * @start: starting userspace address
0329  * @length: length of region to register
0330  * @virt_addr: associated virtual address
0331  * @mr_access_flags: access flags for this memory region
0332  * @udata: unused by the driver
0333  *
0334  * Return: the memory region on success, otherwise returns an errno.
0335  */
0336 struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
0337                   u64 virt_addr, int mr_access_flags,
0338                   struct ib_udata *udata)
0339 {
0340     struct rvt_mr *mr;
0341     struct ib_umem *umem;
0342     struct sg_page_iter sg_iter;
0343     int n, m;
0344     struct ib_mr *ret;
0345 
0346     if (length == 0)
0347         return ERR_PTR(-EINVAL);
0348 
0349     umem = ib_umem_get(pd->device, start, length, mr_access_flags);
0350     if (IS_ERR(umem))
0351         return (void *)umem;
0352 
0353     n = ib_umem_num_pages(umem);
0354 
0355     mr = __rvt_alloc_mr(n, pd);
0356     if (IS_ERR(mr)) {
0357         ret = (struct ib_mr *)mr;
0358         goto bail_umem;
0359     }
0360 
0361     mr->mr.user_base = start;
0362     mr->mr.iova = virt_addr;
0363     mr->mr.length = length;
0364     mr->mr.offset = ib_umem_offset(umem);
0365     mr->mr.access_flags = mr_access_flags;
0366     mr->umem = umem;
0367 
0368     mr->mr.page_shift = PAGE_SHIFT;
0369     m = 0;
0370     n = 0;
0371     for_each_sgtable_page (&umem->sgt_append.sgt, &sg_iter, 0) {
0372         void *vaddr;
0373 
0374         vaddr = page_address(sg_page_iter_page(&sg_iter));
0375         if (!vaddr) {
0376             ret = ERR_PTR(-EINVAL);
0377             goto bail_inval;
0378         }
0379         mr->mr.map[m]->segs[n].vaddr = vaddr;
0380         mr->mr.map[m]->segs[n].length = PAGE_SIZE;
0381         trace_rvt_mr_user_seg(&mr->mr, m, n, vaddr, PAGE_SIZE);
0382         if (++n == RVT_SEGSZ) {
0383             m++;
0384             n = 0;
0385         }
0386     }
0387     return &mr->ibmr;
0388 
0389 bail_inval:
0390     __rvt_free_mr(mr);
0391 
0392 bail_umem:
0393     ib_umem_release(umem);
0394 
0395     return ret;
0396 }
0397 
0398 /**
0399  * rvt_dereg_clean_qp_cb - callback from iterator
0400  * @qp: the qp
0401  * @v: the mregion (as u64)
0402  *
0403  * This routine fields the callback for all QPs; for QPs in the
0404  * same PD as the MR it calls rvt_qp_mr_clean() to potentially
0405  * clean up references.
0406  */
0407 static void rvt_dereg_clean_qp_cb(struct rvt_qp *qp, u64 v)
0408 {
0409     struct rvt_mregion *mr = (struct rvt_mregion *)v;
0410 
0411     /* skip PDs that are not ours */
0412     if (mr->pd != qp->ibqp.pd)
0413         return;
0414     rvt_qp_mr_clean(qp, mr->lkey);
0415 }
0416 
0417 /**
0418  * rvt_dereg_clean_qps - find QPs for reference cleanup
0419  * @mr: the MR that is being deregistered
0420  *
0421  * This routine iterates RC QPs looking for references
0422  * to the lkey noted in mr.
0423  */
0424 static void rvt_dereg_clean_qps(struct rvt_mregion *mr)
0425 {
0426     struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);
0427 
0428     rvt_qp_iter(rdi, (u64)mr, rvt_dereg_clean_qp_cb);
0429 }
0430 
0431 /**
0432  * rvt_check_refs - check references
0433  * @mr: the mregion
0434  * @t: the caller identification
0435  *
0436  * This routine checks whether references are still held on the MR
0437  * while it is being deregistered.
0438  *
0439  * If the count is non-zero, the code calls a clean routine and then
0440  * waits up to a timeout for the count to reach zero.
0441  */
0442 static int rvt_check_refs(struct rvt_mregion *mr, const char *t)
0443 {
0444     unsigned long timeout;
0445     struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);
0446 
0447     if (mr->lkey) {
0448         /* avoid dma mr */
0449         rvt_dereg_clean_qps(mr);
0450         /* @mr was indexed on rcu protected @lkey_table */
0451         synchronize_rcu();
0452     }
0453 
0454     timeout = wait_for_completion_timeout(&mr->comp, 5 * HZ);
0455     if (!timeout) {
0456         rvt_pr_err(rdi,
0457                "%s timeout mr %p pd %p lkey %x refcount %ld\n",
0458                t, mr, mr->pd, mr->lkey,
0459                atomic_long_read(&mr->refcount.data->count));
0460         rvt_get_mr(mr);
0461         return -EBUSY;
0462     }
0463     return 0;
0464 }
0465 
0466 /**
0467  * rvt_mr_has_lkey - check whether mr has the given lkey
0468  * @mr: the mregion
0469  * @lkey: the lkey
0470  */
0471 bool rvt_mr_has_lkey(struct rvt_mregion *mr, u32 lkey)
0472 {
0473     return mr && lkey == mr->lkey;
0474 }
0475 
0476 /**
0477  * rvt_ss_has_lkey - is an MR with this lkey referenced by the SGE state
0478  * @ss: the sge state
0479  * @lkey: the lkey
0480  *
0481  * This code tests for an MR in the indicated
0482  * sge state.
0483  */
0484 bool rvt_ss_has_lkey(struct rvt_sge_state *ss, u32 lkey)
0485 {
0486     int i;
0487     bool rval = false;
0488 
0489     if (!ss->num_sge)
0490         return rval;
0491     /* first one */
0492     rval = rvt_mr_has_lkey(ss->sge.mr, lkey);
0493     /* any others */
0494     for (i = 0; !rval && i < ss->num_sge - 1; i++)
0495         rval = rvt_mr_has_lkey(ss->sg_list[i].mr, lkey);
0496     return rval;
0497 }
0498 
0499 /**
0500  * rvt_dereg_mr - unregister and free a memory region
0501  * @ibmr: the memory region to free
0502  * @udata: unused by the driver
0503  *
0504  * Note that this is called to free MRs created by rvt_get_dma_mr()
0505  * or rvt_reg_user_mr().
0506  *
0507  * Returns 0 on success.
0508  */
0509 int rvt_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
0510 {
0511     struct rvt_mr *mr = to_imr(ibmr);
0512     int ret;
0513 
0514     rvt_free_lkey(&mr->mr);
0515 
0516     rvt_put_mr(&mr->mr); /* will set completion if last */
0517     ret = rvt_check_refs(&mr->mr, __func__);
0518     if (ret)
0519         goto out;
0520     rvt_deinit_mregion(&mr->mr);
0521     ib_umem_release(mr->umem);
0522     kfree(mr);
0523 out:
0524     return ret;
0525 }
0526 
0527 /**
0528  * rvt_alloc_mr - Allocate a memory region usable with the IB_WR_REG_MR work request
0529  * @pd: protection domain for this memory region
0530  * @mr_type: mem region type
0531  * @max_num_sg: Max number of segments allowed
0532  *
0533  * Return: the memory region on success, otherwise return an errno.
0534  */
0535 struct ib_mr *rvt_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
0536                u32 max_num_sg)
0537 {
0538     struct rvt_mr *mr;
0539 
0540     if (mr_type != IB_MR_TYPE_MEM_REG)
0541         return ERR_PTR(-EINVAL);
0542 
0543     mr = __rvt_alloc_mr(max_num_sg, pd);
0544     if (IS_ERR(mr))
0545         return (struct ib_mr *)mr;
0546 
0547     return &mr->ibmr;
0548 }
0549 
0550 /**
0551  * rvt_set_page - page assignment function called by ib_sg_to_pages
0552  * @ibmr: memory region
0553  * @addr: dma address of mapped page
0554  *
0555  * Return: 0 on success
0556  */
0557 static int rvt_set_page(struct ib_mr *ibmr, u64 addr)
0558 {
0559     struct rvt_mr *mr = to_imr(ibmr);
0560     u32 ps = 1 << mr->mr.page_shift;
0561     u32 mapped_segs = mr->mr.length >> mr->mr.page_shift;
0562     int m, n;
0563 
0564     if (unlikely(mapped_segs == mr->mr.max_segs))
0565         return -ENOMEM;
0566 
0567     m = mapped_segs / RVT_SEGSZ;
0568     n = mapped_segs % RVT_SEGSZ;
0569     mr->mr.map[m]->segs[n].vaddr = (void *)addr;
0570     mr->mr.map[m]->segs[n].length = ps;
0571     mr->mr.length += ps;
0572     trace_rvt_mr_page_seg(&mr->mr, m, n, (void *)addr, ps);
0573 
0574     return 0;
0575 }
0576 
0577 /**
0578  * rvt_map_mr_sg - map an sg list into the memory region
0579  * @ibmr: memory region
0580  * @sg: dma mapped scatterlist
0581  * @sg_nents: number of entries in sg
0582  * @sg_offset: offset in bytes into sg
0583  *
0584  * Overwrite rvt_mr length with mr length calculated by ib_sg_to_pages.
0585  *
0586  * Return: number of sg elements mapped to the memory region
0587  */
0588 int rvt_map_mr_sg(struct ib_mr *ibmr, struct scatterlist *sg,
0589           int sg_nents, unsigned int *sg_offset)
0590 {
0591     struct rvt_mr *mr = to_imr(ibmr);
0592     int ret;
0593 
0594     mr->mr.length = 0;
0595     mr->mr.page_shift = PAGE_SHIFT;
0596     ret = ib_sg_to_pages(ibmr, sg, sg_nents, sg_offset, rvt_set_page);
0597     mr->mr.user_base = ibmr->iova;
0598     mr->mr.iova = ibmr->iova;
0599     mr->mr.offset = ibmr->iova - (u64)mr->mr.map[0]->segs[0].vaddr;
0600     mr->mr.length = (size_t)ibmr->length;
0601     trace_rvt_map_mr_sg(ibmr, sg_nents, sg_offset);
0602     return ret;
0603 }
0604 
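/*
 * Hedged usage sketch, not taken from this file: how a kernel ULP might
 * drive this path through the core verbs.  ib_alloc_mr() reaches
 * rvt_alloc_mr() and ib_map_mr_sg() reaches rvt_map_mr_sg(), which lets
 * ib_sg_to_pages() invoke rvt_set_page() for each page.  The function
 * name is hypothetical and the scatterlist is assumed to be DMA mapped
 * already.
 */
static int example_fast_reg_setup(struct ib_pd *pd, struct scatterlist *sg,
                                  int sg_nents)
{
    struct ib_mr *ibmr;
    int n;

    ibmr = ib_alloc_mr(pd, IB_MR_TYPE_MEM_REG, sg_nents);
    if (IS_ERR(ibmr))
        return PTR_ERR(ibmr);

    n = ib_map_mr_sg(ibmr, sg, sg_nents, NULL, PAGE_SIZE);
    if (n != sg_nents) {
        ib_dereg_mr(ibmr);
        return n < 0 ? n : -EINVAL;
    }

    /* the MR would now be posted via an IB_WR_REG_MR work request */
    return 0;
}
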
0605 /**
0606  * rvt_fast_reg_mr - fast register physical MR
0607  * @qp: the queue pair where the work request comes from
0608  * @ibmr: the memory region to be registered
0609  * @key: updated key for this memory region
0610  * @access: access flags for this memory region
0611  *
0612  * Returns 0 on success.
0613  */
0614 int rvt_fast_reg_mr(struct rvt_qp *qp, struct ib_mr *ibmr, u32 key,
0615             int access)
0616 {
0617     struct rvt_mr *mr = to_imr(ibmr);
0618 
0619     if (qp->ibqp.pd != mr->mr.pd)
0620         return -EACCES;
0621 
0622     /* not applicable to dma MR or user MR */
0623     if (!mr->mr.lkey || mr->umem)
0624         return -EINVAL;
0625 
0626     if ((key & 0xFFFFFF00) != (mr->mr.lkey & 0xFFFFFF00))
0627         return -EINVAL;
0628 
0629     ibmr->lkey = key;
0630     ibmr->rkey = key;
0631     mr->mr.lkey = key;
0632     mr->mr.access_flags = access;
0633     mr->mr.iova = ibmr->iova;
0634     atomic_set(&mr->mr.lkey_invalid, 0);
0635 
0636     return 0;
0637 }
0638 EXPORT_SYMBOL(rvt_fast_reg_mr);
0639 
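/*
 * Key-update example with hypothetical values: only the low 8 bits of the
 * key may differ from the current lkey, which is what the 0xFFFFFF00
 * comparison above enforces.  If mr->mr.lkey == 0x00052a00, then
 * key == 0x00052a07 is accepted (new consumer-owned tag 0x07), while
 * key == 0x00062a00 is rejected with -EINVAL because it names a different
 * table slot/generation.
 */
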
0640 /**
0641  * rvt_invalidate_rkey - invalidate an MR rkey
0642  * @qp: queue pair associated with the invalidate op
0643  * @rkey: rkey to invalidate
0644  *
0645  * Returns 0 on success.
0646  */
0647 int rvt_invalidate_rkey(struct rvt_qp *qp, u32 rkey)
0648 {
0649     struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
0650     struct rvt_lkey_table *rkt = &dev->lkey_table;
0651     struct rvt_mregion *mr;
0652 
0653     if (rkey == 0)
0654         return -EINVAL;
0655 
0656     rcu_read_lock();
0657     mr = rcu_dereference(
0658         rkt->table[(rkey >> (32 - dev->dparms.lkey_table_size))]);
0659     if (unlikely(!mr || mr->lkey != rkey || qp->ibqp.pd != mr->pd))
0660         goto bail;
0661 
0662     atomic_set(&mr->lkey_invalid, 1);
0663     rcu_read_unlock();
0664     return 0;
0665 
0666 bail:
0667     rcu_read_unlock();
0668     return -EINVAL;
0669 }
0670 EXPORT_SYMBOL(rvt_invalidate_rkey);
0671 
0672 /**
0673  * rvt_sge_adjacent - can the new SGE be merged into the last SGE
0674  * @last_sge: last outgoing SGE written
0675  * @sge: SGE to check
0676  *
0677  * If adjacent, last_sge is updated to add the new length.
0678  *
0679  * Return: true if sge is adjacent to last_sge
0680  */
0681 static inline bool rvt_sge_adjacent(struct rvt_sge *last_sge,
0682                     struct ib_sge *sge)
0683 {
0684     if (last_sge && sge->lkey == last_sge->mr->lkey &&
0685         ((uint64_t)(last_sge->vaddr + last_sge->length) == sge->addr)) {
0686         if (sge->lkey) {
0687             if (unlikely((sge->addr - last_sge->mr->user_base +
0688                   sge->length > last_sge->mr->length)))
0689                 return false; /* overrun, caller will catch */
0690         } else {
0691             last_sge->length += sge->length;
0692         }
0693         last_sge->sge_length += sge->length;
0694         trace_rvt_sge_adjacent(last_sge, sge);
0695         return true;
0696     }
0697     return false;
0698 }
0699 
0700 /**
0701  * rvt_lkey_ok - check IB SGE for validity and initialize
0702  * @rkt: table containing lkey to check SGE against
0703  * @pd: protection domain
0704  * @isge: outgoing internal SGE
0705  * @last_sge: last outgoing SGE written
0706  * @sge: SGE to check
0707  * @acc: access flags
0708  *
0709  * Check the IB SGE for validity and initialize our internal version
0710  * of it.
0711  *
0712  * Increments the reference count when a new sge is stored.
0713  *
0714  * Return: 0 if compressed, 1 if added, otherwise returns -errno.
0715  */
0716 int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd,
0717         struct rvt_sge *isge, struct rvt_sge *last_sge,
0718         struct ib_sge *sge, int acc)
0719 {
0720     struct rvt_mregion *mr;
0721     unsigned n, m;
0722     size_t off;
0723 
0724     /*
0725      * We use LKEY == zero for kernel virtual addresses
0726      * (see rvt_get_dma_mr()).
0727      */
0728     if (sge->lkey == 0) {
0729         struct rvt_dev_info *dev = ib_to_rvt(pd->ibpd.device);
0730 
0731         if (pd->user)
0732             return -EINVAL;
0733         if (rvt_sge_adjacent(last_sge, sge))
0734             return 0;
0735         rcu_read_lock();
0736         mr = rcu_dereference(dev->dma_mr);
0737         if (!mr)
0738             goto bail;
0739         rvt_get_mr(mr);
0740         rcu_read_unlock();
0741 
0742         isge->mr = mr;
0743         isge->vaddr = (void *)sge->addr;
0744         isge->length = sge->length;
0745         isge->sge_length = sge->length;
0746         isge->m = 0;
0747         isge->n = 0;
0748         goto ok;
0749     }
0750     if (rvt_sge_adjacent(last_sge, sge))
0751         return 0;
0752     rcu_read_lock();
0753     mr = rcu_dereference(rkt->table[sge->lkey >> rkt->shift]);
0754     if (!mr)
0755         goto bail;
0756     rvt_get_mr(mr);
0757     if (!READ_ONCE(mr->lkey_published))
0758         goto bail_unref;
0759 
0760     if (unlikely(atomic_read(&mr->lkey_invalid) ||
0761              mr->lkey != sge->lkey || mr->pd != &pd->ibpd))
0762         goto bail_unref;
0763 
0764     off = sge->addr - mr->user_base;
0765     if (unlikely(sge->addr < mr->user_base ||
0766              off + sge->length > mr->length ||
0767              (mr->access_flags & acc) != acc))
0768         goto bail_unref;
0769     rcu_read_unlock();
0770 
0771     off += mr->offset;
0772     if (mr->page_shift) {
0773         /*
0774          * page sizes are a uniform power of 2, so no loop is necessary.
0775          * entries_spanned_by_off is the number of times the loop below
0776          * would have executed.
0777          */
0778         size_t entries_spanned_by_off;
0779 
0780         entries_spanned_by_off = off >> mr->page_shift;
0781         off -= (entries_spanned_by_off << mr->page_shift);
0782         m = entries_spanned_by_off / RVT_SEGSZ;
0783         n = entries_spanned_by_off % RVT_SEGSZ;
0784     } else {
0785         m = 0;
0786         n = 0;
0787         while (off >= mr->map[m]->segs[n].length) {
0788             off -= mr->map[m]->segs[n].length;
0789             n++;
0790             if (n >= RVT_SEGSZ) {
0791                 m++;
0792                 n = 0;
0793             }
0794         }
0795     }
0796     isge->mr = mr;
0797     isge->vaddr = mr->map[m]->segs[n].vaddr + off;
0798     isge->length = mr->map[m]->segs[n].length - off;
0799     isge->sge_length = sge->length;
0800     isge->m = m;
0801     isge->n = n;
0802 ok:
0803     trace_rvt_sge_new(isge, sge);
0804     return 1;
0805 bail_unref:
0806     rvt_put_mr(mr);
0807 bail:
0808     rcu_read_unlock();
0809     return -EINVAL;
0810 }
0811 EXPORT_SYMBOL(rvt_lkey_ok);
0812 
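/*
 * Worked example of the uniform-page-size fast path above; the numbers
 * are illustrative.  With page_shift == 12 (4 KiB pages) and a byte
 * offset into the region of off == 0x6a30:
 *
 *   entries_spanned_by_off = 0x6a30 >> 12 == 6
 *   off -= 6 << 12                             (leaves off == 0xa30)
 *   m = 6 / RVT_SEGSZ;  n = 6 % RVT_SEGSZ
 *
 * so the SGE starts 0xa30 bytes into map[m]->segs[n].  rvt_rkey_ok()
 * below performs the same computation for remote keys.
 */
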
0813 /**
0814  * rvt_rkey_ok - check the IB virtual address, length, and RKEY
0815  * @qp: qp for validation
0816  * @sge: SGE state
0817  * @len: length of data
0818  * @vaddr: virtual address to place data
0819  * @rkey: rkey to check
0820  * @acc: access flags
0821  *
0822  * Return: 1 if successful, otherwise 0.
0823  *
0824  * Increments the reference count upon success.
0825  */
0826 int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge,
0827         u32 len, u64 vaddr, u32 rkey, int acc)
0828 {
0829     struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
0830     struct rvt_lkey_table *rkt = &dev->lkey_table;
0831     struct rvt_mregion *mr;
0832     unsigned n, m;
0833     size_t off;
0834 
0835     /*
0836      * We use RKEY == zero for kernel virtual addresses
0837      * (see rvt_get_dma_mr()).
0838      */
0839     rcu_read_lock();
0840     if (rkey == 0) {
0841         struct rvt_pd *pd = ibpd_to_rvtpd(qp->ibqp.pd);
0842         struct rvt_dev_info *rdi = ib_to_rvt(pd->ibpd.device);
0843 
0844         if (pd->user)
0845             goto bail;
0846         mr = rcu_dereference(rdi->dma_mr);
0847         if (!mr)
0848             goto bail;
0849         rvt_get_mr(mr);
0850         rcu_read_unlock();
0851 
0852         sge->mr = mr;
0853         sge->vaddr = (void *)vaddr;
0854         sge->length = len;
0855         sge->sge_length = len;
0856         sge->m = 0;
0857         sge->n = 0;
0858         goto ok;
0859     }
0860 
0861     mr = rcu_dereference(rkt->table[rkey >> rkt->shift]);
0862     if (!mr)
0863         goto bail;
0864     rvt_get_mr(mr);
0865     /* ensure mr read is before test */
0866     if (!READ_ONCE(mr->lkey_published))
0867         goto bail_unref;
0868     if (unlikely(atomic_read(&mr->lkey_invalid) ||
0869              mr->lkey != rkey || qp->ibqp.pd != mr->pd))
0870         goto bail_unref;
0871 
0872     off = vaddr - mr->iova;
0873     if (unlikely(vaddr < mr->iova || off + len > mr->length ||
0874              (mr->access_flags & acc) == 0))
0875         goto bail_unref;
0876     rcu_read_unlock();
0877 
0878     off += mr->offset;
0879     if (mr->page_shift) {
0880         /*
0881          * page sizes are a uniform power of 2, so no loop is necessary.
0882          * entries_spanned_by_off is the number of times the loop below
0883          * would have executed.
0884          */
0885         size_t entries_spanned_by_off;
0886 
0887         entries_spanned_by_off = off >> mr->page_shift;
0888         off -= (entries_spanned_by_off << mr->page_shift);
0889         m = entries_spanned_by_off / RVT_SEGSZ;
0890         n = entries_spanned_by_off % RVT_SEGSZ;
0891     } else {
0892         m = 0;
0893         n = 0;
0894         while (off >= mr->map[m]->segs[n].length) {
0895             off -= mr->map[m]->segs[n].length;
0896             n++;
0897             if (n >= RVT_SEGSZ) {
0898                 m++;
0899                 n = 0;
0900             }
0901         }
0902     }
0903     sge->mr = mr;
0904     sge->vaddr = mr->map[m]->segs[n].vaddr + off;
0905     sge->length = mr->map[m]->segs[n].length - off;
0906     sge->sge_length = len;
0907     sge->m = m;
0908     sge->n = n;
0909 ok:
0910     return 1;
0911 bail_unref:
0912     rvt_put_mr(mr);
0913 bail:
0914     rcu_read_unlock();
0915     return 0;
0916 }
0917 EXPORT_SYMBOL(rvt_rkey_ok);