// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>
#include <linux/cc_platform.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define PRI_QUEUE_SIZE      512

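/* Per-tag state of a Page Request Interface (PRI) request in flight */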
struct pri_queue {
    atomic_t inflight;
    bool finish;
    int status;
};

struct pasid_state {
    struct list_head list;                  /* For global state-list */
    refcount_t count;                       /* Reference count */
    unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                               calls */
    struct mm_struct *mm;                   /* mm_struct for the faults */
    struct mmu_notifier mn;                 /* mmu_notifier handle */
    struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
    struct device_state *device_state;      /* Link to our device_state */
    u32 pasid;                              /* PASID index */
    bool invalid;                           /* Used during setup and
                                               teardown of the pasid */
    spinlock_t lock;                        /* Protect pri_queues and
                                               mmu_notifier_count */
    wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
    struct list_head list;                  /* For global state_list */
    u32 sbdf;                               /* PCI seg/bus/device/function id */
    atomic_t count;                         /* Reference count */
    struct pci_dev *pdev;
    struct pasid_state **states;            /* Root of the pasid-state table */
    struct iommu_domain *domain;
    int pasid_levels;                       /* Levels of the pasid-state table */
    int max_pasids;
    amd_iommu_invalid_ppr_cb inv_ppr_cb;    /* Driver callback for failed PPRs */
    amd_iommu_invalidate_ctx inv_ctx_cb;    /* Driver callback on mm teardown */
    spinlock_t lock;
    wait_queue_head_t wq;                   /* To wait for count == 0 */
};

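/* Work item for one page fault taken from the IOMMU's PPR log */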
struct fault {
    struct work_struct work;
    struct device_state *dev_state;
    struct pasid_state *state;
    struct mm_struct *mm;
    u64 address;
    u32 pasid;
    u16 tag;
    u16 finish;
    u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

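/* Look up a device_state by PCI seg/bus/dev/fn id; caller must hold state_lock */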
static struct device_state *__get_device_state(u32 sbdf)
{
    struct device_state *dev_state;

    list_for_each_entry(dev_state, &state_list, list) {
        if (dev_state->sbdf == sbdf)
            return dev_state;
    }

    return NULL;
}

static struct device_state *get_device_state(u32 sbdf)
{
    struct device_state *dev_state;
    unsigned long flags;

    spin_lock_irqsave(&state_lock, flags);
    dev_state = __get_device_state(sbdf);
    if (dev_state != NULL)
        atomic_inc(&dev_state->count);
    spin_unlock_irqrestore(&state_lock, flags);

    return dev_state;
}

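/*
 * Tear down a device_state: free all pasid states, wait for the last
 * reference to drop, detach the device's group from the IOMMUv2 domain
 * and free the domain itself.
 */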
static void free_device_state(struct device_state *dev_state)
{
    struct iommu_group *group;

    /* Get rid of any remaining pasid states */
    free_pasid_states(dev_state);

    /*
     * Wait until the last reference is dropped before freeing
     * the device state.
     */
    wait_event(dev_state->wq, !atomic_read(&dev_state->count));

    /*
     * First detach device from domain - No more PRI requests will arrive
     * from that device after it is unbound from the IOMMUv2 domain.
     */
    group = iommu_group_get(&dev_state->pdev->dev);
    if (WARN_ON(!group))
        return;

    iommu_detach_group(dev_state->domain, group);

    iommu_group_put(group);

    /* Everything is down now, free the IOMMUv2 domain */
    iommu_domain_free(dev_state->domain);

    /* Finally get rid of the device-state */
    kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
    if (atomic_dec_and_test(&dev_state->count))
        wake_up(&dev_state->wq);
}

/*
 * Walk the multi-level pasid-state table and return a pointer to the
 * slot for @pasid. Each level indexes 9 bits of the PASID (512 entries
 * per page), like a page-table walk. With @alloc set, missing
 * intermediate levels are allocated on the fly. Must be called under
 * dev_state->lock.
 */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                          u32 pasid, bool alloc)
{
    struct pasid_state **root, **ptr;
    int level, index;

    level = dev_state->pasid_levels;
    root  = dev_state->states;

    while (true) {

        index = (pasid >> (9 * level)) & 0x1ff;
        ptr   = &root[index];

        if (level == 0)
            break;

        if (*ptr == NULL) {
            if (!alloc)
                return NULL;

            *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
            if (*ptr == NULL)
                return NULL;
        }

        root   = (struct pasid_state **)*ptr;
        level -= 1;
    }

    return ptr;
}

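/*
 * Install @pasid_state into the table slot for @pasid. Returns -ENOMEM
 * if the slot cannot be allocated or is already occupied.
 */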
static int set_pasid_state(struct device_state *dev_state,
               struct pasid_state *pasid_state,
               u32 pasid)
{
    struct pasid_state **ptr;
    unsigned long flags;
    int ret;

    spin_lock_irqsave(&dev_state->lock, flags);
    ptr = __get_pasid_state_ptr(dev_state, pasid, true);

    ret = -ENOMEM;
    if (ptr == NULL)
        goto out_unlock;

    ret = -ENOMEM;
    if (*ptr != NULL)
        goto out_unlock;

    *ptr = pasid_state;

    ret = 0;

out_unlock:
    spin_unlock_irqrestore(&dev_state->lock, flags);

    return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
    struct pasid_state **ptr;
    unsigned long flags;

    spin_lock_irqsave(&dev_state->lock, flags);
    ptr = __get_pasid_state_ptr(dev_state, pasid, true);

    if (ptr == NULL)
        goto out_unlock;

    *ptr = NULL;

out_unlock:
    spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                       u32 pasid)
{
    struct pasid_state **ptr, *ret = NULL;
    unsigned long flags;

    spin_lock_irqsave(&dev_state->lock, flags);
    ptr = __get_pasid_state_ptr(dev_state, pasid, false);

    if (ptr == NULL)
        goto out_unlock;

    ret = *ptr;
    if (ret)
        refcount_inc(&ret->count);

out_unlock:
    spin_unlock_irqrestore(&dev_state->lock, flags);

    return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
    kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
    if (refcount_dec_and_test(&pasid_state->count))
        wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
    refcount_dec(&pasid_state->count);
    wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
    free_pasid_state(pasid_state);
}

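/*
 * Stop fault handling for a PASID: mark the state invalid, clear the
 * GCR3 entry so the device can no longer access the mm, then drain the
 * fault workqueue.
 */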
static void unbind_pasid(struct pasid_state *pasid_state)
{
    struct iommu_domain *domain;

    domain = pasid_state->device_state->domain;

    /*
     * Mark pasid_state as invalid; no more faults will be added to the
     * work queue after this is visible everywhere.
     */
    pasid_state->invalid = true;

    /* Make sure this is visible */
    smp_wmb();

    /* After this the device/pasid can't access the mm anymore */
    amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

    /* Make sure no more pending faults are in the queue */
    flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
    int i;

    for (i = 0; i < 512; ++i) {
        if (tbl[i] == NULL)
            continue;

        free_page((unsigned long)tbl[i]);
    }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
    struct pasid_state **ptr;
    int i;

    for (i = 0; i < 512; ++i) {
        if (tbl[i] == NULL)
            continue;

        ptr = (struct pasid_state **)tbl[i];
        free_pasid_states_level1(ptr);
    }
}

static void free_pasid_states(struct device_state *dev_state)
{
    struct pasid_state *pasid_state;
    int i;

    for (i = 0; i < dev_state->max_pasids; ++i) {
        pasid_state = get_pasid_state(dev_state, i);
        if (pasid_state == NULL)
            continue;

        put_pasid_state(pasid_state);

        /*
         * This will call the mn_release function and
         * unbind the PASID
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
    }

    if (dev_state->pasid_levels == 2)
        free_pasid_states_level2(dev_state->states);
    else if (dev_state->pasid_levels == 1)
        free_pasid_states_level1(dev_state->states);
    else
        BUG_ON(dev_state->pasid_levels != 0);

    free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
    return container_of(mn, struct pasid_state, mn);
}

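/*
 * mmu_notifier callback: the CPU page tables for [start, end) changed,
 * so the matching IOTLB entries must be flushed. If start and end - 1
 * differ only in the page offset, the range fits in a single page and
 * only that page is flushed; otherwise the whole TLB for the PASID is.
 */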
static void mn_invalidate_range(struct mmu_notifier *mn,
                struct mm_struct *mm,
                unsigned long start, unsigned long end)
{
    struct pasid_state *pasid_state;
    struct device_state *dev_state;

    pasid_state = mn_to_state(mn);
    dev_state   = pasid_state->device_state;

    if ((start ^ (end - 1)) < PAGE_SIZE)
        amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
                     start);
    else
        amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
    struct pasid_state *pasid_state;
    struct device_state *dev_state;
    bool run_inv_ctx_cb;

    might_sleep();

    pasid_state    = mn_to_state(mn);
    dev_state      = pasid_state->device_state;
    run_inv_ctx_cb = !pasid_state->invalid;

    if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
        dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

    unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
    .release          = mn_release,
    .invalidate_range = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                   u16 tag, int status)
{
    unsigned long flags;

    spin_lock_irqsave(&pasid_state->lock, flags);
    pasid_state->pri[tag].status = status;
    spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
               struct pasid_state *pasid_state,
               u16 tag)
{
    unsigned long flags;

    spin_lock_irqsave(&pasid_state->lock, flags);
    if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
        pasid_state->pri[tag].finish) {
        amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                       pasid_state->pri[tag].status, tag);
        pasid_state->pri[tag].finish = false;
        pasid_state->pri[tag].status = PPR_SUCCESS;
    }
    spin_unlock_irqrestore(&pasid_state->lock, flags);
}

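/*
 * Page fault handling failed. Ask the device driver's invalid-PPR
 * callback, if one is registered, how to answer the PPR; otherwise
 * report the request as invalid.
 */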
static void handle_fault_error(struct fault *fault)
{
    int status;

    if (!fault->dev_state->inv_ppr_cb) {
        set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
        return;
    }

    status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                          fault->pasid,
                          fault->address,
                          fault->flags);
    switch (status) {
    case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
        set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
        break;
    case AMD_IOMMU_INV_PRI_RSP_INVALID:
        set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
        break;
    case AMD_IOMMU_INV_PRI_RSP_FAIL:
        set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
        break;
    default:
        BUG();
    }
}

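/* Check whether the VMA permits the access requested in the fault flags */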
static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
    unsigned long requested = 0;

    if (fault->flags & PPR_FAULT_EXEC)
        requested |= VM_EXEC;

    if (fault->flags & PPR_FAULT_READ)
        requested |= VM_READ;

    if (fault->flags & PPR_FAULT_WRITE)
        requested |= VM_WRITE;

    return (requested & ~vma->vm_flags) != 0;
}

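/*
 * Workqueue handler for one queued fault: resolve it with
 * handle_mm_fault() under mmap_read_lock and complete the PRI tag.
 */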
static void do_fault(struct work_struct *work)
{
    struct fault *fault = container_of(work, struct fault, work);
    struct vm_area_struct *vma;
    vm_fault_t ret = VM_FAULT_ERROR;
    unsigned int flags = 0;
    struct mm_struct *mm;
    u64 address;

    mm = fault->state->mm;
    address = fault->address;

    if (fault->flags & PPR_FAULT_USER)
        flags |= FAULT_FLAG_USER;
    if (fault->flags & PPR_FAULT_WRITE)
        flags |= FAULT_FLAG_WRITE;
    flags |= FAULT_FLAG_REMOTE;

    mmap_read_lock(mm);
    vma = find_extend_vma(mm, address);
    if (!vma || address < vma->vm_start)
        /* failed to get a vma in the right range */
        goto out;

    /* Check if we have the right permissions on the vma */
    if (access_error(vma, fault))
        goto out;

    ret = handle_mm_fault(vma, address, flags, NULL);
out:
    mmap_read_unlock(mm);

    if (ret & VM_FAULT_ERROR)
        /* failed to service fault */
        handle_fault_error(fault);

    finish_pri_tag(fault->dev_state, fault->state, fault->tag);

    put_pasid_state(fault->state);

    kfree(fault);
}

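/*
 * Notifier-chain callback for incoming PPR log entries. May run in
 * atomic context (hence the GFP_ATOMIC allocation): validate device
 * and PASID, mark the PRI tag in flight and hand the fault off to the
 * workqueue for the actual mm fault handling.
 */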
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
    struct amd_iommu_fault *iommu_fault;
    struct pasid_state *pasid_state;
    struct device_state *dev_state;
    struct pci_dev *pdev = NULL;
    unsigned long flags;
    struct fault *fault;
    bool finish;
    u16 tag, devid, seg_id;
    int ret;

    iommu_fault = data;
    tag         = iommu_fault->tag & 0x1ff;
    finish      = (iommu_fault->tag >> 9) & 1;

    seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf);
    devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf);
    pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid),
                       devid & 0xff);
    if (!pdev)
        return -ENODEV;

    ret = NOTIFY_DONE;

    /* In kdump kernel pci dev is not initialized yet -> send INVALID */
    if (amd_iommu_is_attach_deferred(&pdev->dev)) {
        amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
                       PPR_INVALID, tag);
        goto out;
    }

    dev_state = get_device_state(iommu_fault->sbdf);
    if (dev_state == NULL)
        goto out;

    pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
    if (pasid_state == NULL || pasid_state->invalid) {
        /* We know the device but not the PASID -> send INVALID */
        amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                       PPR_INVALID, tag);
        goto out_drop_state;
    }

    spin_lock_irqsave(&pasid_state->lock, flags);
    atomic_inc(&pasid_state->pri[tag].inflight);
    if (finish)
        pasid_state->pri[tag].finish = true;
    spin_unlock_irqrestore(&pasid_state->lock, flags);

    fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
    if (fault == NULL) {
        /* We are OOM - send success and let the device re-fault */
        finish_pri_tag(dev_state, pasid_state, tag);
        goto out_drop_state;
    }

    fault->dev_state = dev_state;
    fault->address   = iommu_fault->address;
    fault->state     = pasid_state;
    fault->tag       = tag;
    fault->finish    = finish;
    fault->pasid     = iommu_fault->pasid;
    fault->flags     = iommu_fault->flags;
    INIT_WORK(&fault->work, do_fault);

    queue_work(iommu_wq, &fault->work);

    ret = NOTIFY_OK;

out_drop_state:

    if (ret != NOTIFY_OK && pasid_state)
        put_pasid_state(pasid_state);

    put_device_state(dev_state);

out:
    return ret;
}

static struct notifier_block ppr_nb = {
    .notifier_call = ppr_notifier,
};

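/*
 * amd_iommu_bind_pasid - Bind a task's mm to a PASID on a device
 *
 * Takes a reference on the task's mm, registers an mmu_notifier for it,
 * installs the pasid_state and programs the mm's page-table root into
 * the GCR3 table so the device can fault pages in via PRI.
 */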
int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
             struct task_struct *task)
{
    struct pasid_state *pasid_state;
    struct device_state *dev_state;
    struct mm_struct *mm;
    u32 sbdf;
    int ret;

    might_sleep();

    if (!amd_iommu_v2_supported())
        return -ENODEV;

    sbdf      = get_pci_sbdf_id(pdev);
    dev_state = get_device_state(sbdf);

    if (dev_state == NULL)
        return -EINVAL;

    ret = -EINVAL;
    if (pasid >= dev_state->max_pasids)
        goto out;

    ret = -ENOMEM;
    pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
    if (pasid_state == NULL)
        goto out;

    refcount_set(&pasid_state->count, 1);
    init_waitqueue_head(&pasid_state->wq);
    spin_lock_init(&pasid_state->lock);

    mm                        = get_task_mm(task);
    pasid_state->mm           = mm;
    pasid_state->device_state = dev_state;
    pasid_state->pasid        = pasid;
    pasid_state->invalid      = true; /* Mark as valid only if we are
                                         done with setting up the pasid */
    pasid_state->mn.ops       = &iommu_mn;

    if (pasid_state->mm == NULL)
        goto out_free;

    mmu_notifier_register(&pasid_state->mn, mm);

    ret = set_pasid_state(dev_state, pasid_state, pasid);
    if (ret)
        goto out_unregister;

    ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                    __pa(pasid_state->mm->pgd));
    if (ret)
        goto out_clear_state;

    /* Now we are ready to handle faults */
    pasid_state->invalid = false;

    /*
     * Drop the reference to the mm_struct here. We rely on the
     * mmu_notifier release call-back to inform us when the mm
     * is going away.
     */
    mmput(mm);

    return 0;

out_clear_state:
    clear_pasid_state(dev_state, pasid);

out_unregister:
    mmu_notifier_unregister(&pasid_state->mn, mm);
    mmput(mm);

out_free:
    free_pasid_state(pasid_state);

out:
    put_device_state(dev_state);

    return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

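/*
 * amd_iommu_unbind_pasid - Undo amd_iommu_bind_pasid
 *
 * Clears the pasid state, unregisters the mmu_notifier and drops the
 * references taken when the PASID was bound.
 */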
void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
    struct pasid_state *pasid_state;
    struct device_state *dev_state;
    u32 sbdf;

    might_sleep();

    if (!amd_iommu_v2_supported())
        return;

    sbdf = get_pci_sbdf_id(pdev);
    dev_state = get_device_state(sbdf);
    if (dev_state == NULL)
        return;

    if (pasid >= dev_state->max_pasids)
        goto out;

    pasid_state = get_pasid_state(dev_state, pasid);
    if (pasid_state == NULL)
        goto out;
    /*
     * Drop reference taken here. We are safe because we still hold
     * the reference taken in the amd_iommu_bind_pasid function.
     */
    put_pasid_state(pasid_state);

    /* Clear the pasid state so that the pasid can be re-used */
    clear_pasid_state(dev_state, pasid_state->pasid);

    /*
     * Call mmu_notifier_unregister to drop our reference
     * to pasid_state->mm
     */
    mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

    put_pasid_state_wait(pasid_state); /* Reference taken in
                                          amd_iommu_bind_pasid */
out:
    /* Drop reference taken in this function */
    put_device_state(dev_state);

    /* Drop reference taken in amd_iommu_bind_pasid */
    put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

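/*
 * amd_iommu_init_device - Enable PRI/PASID handling for a device
 *
 * Allocates the per-device state for up to @pasids PASIDs, sets up a
 * direct-mapped IOMMUv2 domain for the device and attaches the
 * device's group to that domain.
 */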
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
    struct device_state *dev_state;
    struct iommu_group *group;
    unsigned long flags;
    int ret, tmp;
    u32 sbdf;

    might_sleep();

    /*
     * When memory encryption is active the device is likely not in a
     * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
     */
    if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
        return -ENODEV;

    if (!amd_iommu_v2_supported())
        return -ENODEV;

    if (pasids <= 0 || pasids > (PASID_MASK + 1))
        return -EINVAL;

    sbdf = get_pci_sbdf_id(pdev);

    dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
    if (dev_state == NULL)
        return -ENOMEM;

    spin_lock_init(&dev_state->lock);
    init_waitqueue_head(&dev_state->wq);
    dev_state->pdev  = pdev;
    dev_state->sbdf = sbdf;

    tmp = pasids;
    for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
        dev_state->pasid_levels += 1;

    atomic_set(&dev_state->count, 1);
    dev_state->max_pasids = pasids;

    ret = -ENOMEM;
    dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
    if (dev_state->states == NULL)
        goto out_free_dev_state;

    dev_state->domain = iommu_domain_alloc(&pci_bus_type);
    if (dev_state->domain == NULL)
        goto out_free_states;

    /* See iommu_is_default_domain() */
    dev_state->domain->type = IOMMU_DOMAIN_IDENTITY;
    amd_iommu_domain_direct_map(dev_state->domain);

    ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
    if (ret)
        goto out_free_domain;

    group = iommu_group_get(&pdev->dev);
    if (!group) {
        ret = -EINVAL;
        goto out_free_domain;
    }

    ret = iommu_attach_group(dev_state->domain, group);
    if (ret != 0)
        goto out_drop_group;

    iommu_group_put(group);

    spin_lock_irqsave(&state_lock, flags);

    if (__get_device_state(sbdf) != NULL) {
        spin_unlock_irqrestore(&state_lock, flags);
        ret = -EBUSY;
        goto out_free_domain;
    }

    list_add_tail(&dev_state->list, &state_list);

    spin_unlock_irqrestore(&state_lock, flags);

    return 0;

out_drop_group:
    iommu_group_put(group);

out_free_domain:
    iommu_domain_free(dev_state->domain);

out_free_states:
    free_page((unsigned long)dev_state->states);

out_free_dev_state:
    kfree(dev_state);

    return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
    struct device_state *dev_state;
    unsigned long flags;
    u32 sbdf;

    if (!amd_iommu_v2_supported())
        return;

    sbdf = get_pci_sbdf_id(pdev);

    spin_lock_irqsave(&state_lock, flags);

    dev_state = __get_device_state(sbdf);
    if (dev_state == NULL) {
        spin_unlock_irqrestore(&state_lock, flags);
        return;
    }

    list_del(&dev_state->list);

    spin_unlock_irqrestore(&state_lock, flags);

    put_device_state(dev_state);
    free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                 amd_iommu_invalid_ppr_cb cb)
{
    struct device_state *dev_state;
    unsigned long flags;
    u32 sbdf;
    int ret;

    if (!amd_iommu_v2_supported())
        return -ENODEV;

    sbdf = get_pci_sbdf_id(pdev);

    spin_lock_irqsave(&state_lock, flags);

    ret = -EINVAL;
    dev_state = __get_device_state(sbdf);
    if (dev_state == NULL)
        goto out_unlock;

    dev_state->inv_ppr_cb = cb;

    ret = 0;

out_unlock:
    spin_unlock_irqrestore(&state_lock, flags);

    return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                    amd_iommu_invalidate_ctx cb)
{
    struct device_state *dev_state;
    unsigned long flags;
    u32 sbdf;
    int ret;

    if (!amd_iommu_v2_supported())
        return -ENODEV;

    sbdf = get_pci_sbdf_id(pdev);

    spin_lock_irqsave(&state_lock, flags);

    ret = -EINVAL;
    dev_state = __get_device_state(sbdf);
    if (dev_state == NULL)
        goto out_unlock;

    dev_state->inv_ctx_cb = cb;

    ret = 0;

out_unlock:
    spin_unlock_irqrestore(&state_lock, flags);

    return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

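/* Module init: allocate the fault workqueue and register for PPR log entries */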
static int __init amd_iommu_v2_init(void)
{
    int ret;

    if (!amd_iommu_v2_supported()) {
        pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
        /*
         * Load anyway to provide the symbols to other modules
         * which may use AMD IOMMUv2 optionally.
         */
        return 0;
    }

    ret = -ENOMEM;
    iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
    if (iommu_wq == NULL)
        goto out;

    amd_iommu_register_ppr_notifier(&ppr_nb);

    pr_info("AMD IOMMUv2 loaded and initialized\n");

    return 0;

out:
    return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
    struct device_state *dev_state, *next;
    unsigned long flags;
    LIST_HEAD(freelist);

    if (!amd_iommu_v2_supported())
        return;

    amd_iommu_unregister_ppr_notifier(&ppr_nb);

    flush_workqueue(iommu_wq);

    /*
     * The loop below might call flush_workqueue(), so call
     * destroy_workqueue() after it
     */
    spin_lock_irqsave(&state_lock, flags);

    list_for_each_entry_safe(dev_state, next, &state_list, list) {
        WARN_ON_ONCE(1);

        put_device_state(dev_state);
        list_del(&dev_state->list);
        list_add_tail(&dev_state->list, &freelist);
    }

    spin_unlock_irqrestore(&state_lock, flags);

    /*
     * Since free_device_state waits on the count to be zero,
     * we need to free dev_state outside the spinlock.
     */
    list_for_each_entry_safe(dev_state, next, &freelist, list) {
        list_del(&dev_state->list);
        free_device_state(dev_state);
    }

    destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);