0001
0002
0003
0004 #undef pr_fmt
0005 #define pr_fmt(fmt) "tdx: " fmt
0006
0007 #include <linux/cpufeature.h>
0008 #include <asm/coco.h>
0009 #include <asm/tdx.h>
0010 #include <asm/vmx.h>
0011 #include <asm/insn.h>
0012 #include <asm/insn-eval.h>
0013 #include <asm/pgtable.h>
0014
0015
0016 #define TDX_GET_INFO 1
0017 #define TDX_GET_VEINFO 3
0018 #define TDX_ACCEPT_PAGE 6
0019
0020
0021 #define TDVMCALL_MAP_GPA 0x10001
0022
0023
0024 #define EPT_READ 0
0025 #define EPT_WRITE 1
0026
0027
0028 #define PORT_READ 0
0029 #define PORT_WRITE 1
0030
0031
0032 #define VE_IS_IO_IN(e) ((e) & BIT(3))
0033 #define VE_GET_IO_SIZE(e) (((e) & GENMASK(2, 0)) + 1)
0034 #define VE_GET_PORT_NUM(e) ((e) >> 16)
0035 #define VE_IS_IO_STRING(e) ((e) & BIT(4))
0036
0037
0038
0039
0040
0041 static inline u64 _tdx_hypercall(u64 fn, u64 r12, u64 r13, u64 r14, u64 r15)
0042 {
0043 struct tdx_hypercall_args args = {
0044 .r10 = TDX_HYPERCALL_STANDARD,
0045 .r11 = fn,
0046 .r12 = r12,
0047 .r13 = r13,
0048 .r14 = r14,
0049 .r15 = r15,
0050 };
0051
0052 return __tdx_hypercall(&args, 0);
0053 }
0054
0055
/*
 * Last-resort error path for a failing TDVMCALL.
 * NOTE(review): presumably invoked from the __tdx_hypercall() assembly
 * when the TDCALL itself reports failure — confirm against the tdcall
 * assembly implementation. A failure here is unrecoverable, hence panic.
 */
void __tdx_hypercall_failed(void)
{
	panic("TDVMCALL failed. TDX module bug?");
}
0060
0061
0062
0063
0064
0065
0066
/*
 * The TDVMCALL sub-function numbers used by this file are matched 1:1
 * with the VMX EXIT_REASON_* values (see the callers passing
 * EXIT_REASON_HLT, EXIT_REASON_CPUID, etc.). This helper makes that
 * mapping explicit at the call sites; today it is the identity function.
 */
static u64 hcall_func(u64 exit_reason)
{
	return exit_reason;
}
0071
0072 #ifdef CONFIG_KVM_GUEST
0073 long tdx_kvm_hypercall(unsigned int nr, unsigned long p1, unsigned long p2,
0074 unsigned long p3, unsigned long p4)
0075 {
0076 struct tdx_hypercall_args args = {
0077 .r10 = nr,
0078 .r11 = p1,
0079 .r12 = p2,
0080 .r13 = p3,
0081 .r14 = p4,
0082 };
0083
0084 return __tdx_hypercall(&args, 0);
0085 }
0086 EXPORT_SYMBOL_GPL(tdx_kvm_hypercall);
0087 #endif
0088
0089
0090
0091
0092
0093
0094 static inline void tdx_module_call(u64 fn, u64 rcx, u64 rdx, u64 r8, u64 r9,
0095 struct tdx_module_output *out)
0096 {
0097 if (__tdx_module_call(fn, rcx, rdx, r8, r9, out))
0098 panic("TDCALL %lld failed (Buggy TDX module!)\n", fn);
0099 }
0100
static u64 get_cc_mask(void)
{
	struct tdx_module_output out;
	unsigned int gpa_width;

	/*
	 * TDX_GET_INFO reports guest configuration via a TDCALL; a failure
	 * panics inside tdx_module_call(). The guest physical address
	 * (GPA) width is returned in the low 6 bits of RCX.
	 */
	tdx_module_call(TDX_GET_INFO, 0, 0, 0, 0, &out);

	gpa_width = out.rcx & GENMASK(5, 0);

	/*
	 * The most significant bit within the GPA width is used as the
	 * confidential-computing mask (the shared/private bit handled by
	 * cc_mkdec()/cc_mkenc()).
	 * NOTE(review): the "shared bit == top GPA bit" convention comes
	 * from the TDX architecture spec — confirm against the module ABI.
	 */
	return BIT_ULL(gpa_width - 1);
}
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
/*
 * Return the length, in bytes, of the instruction that caused this #VE,
 * i.e. how far to advance RIP after handling it. ve->instr_len as
 * reported by the TDX module is only architecturally defined for some
 * exit reasons.
 */
static int ve_instr_len(struct ve_info *ve)
{
	switch (ve->exit_reason) {
	case EXIT_REASON_HLT:
	case EXIT_REASON_MSR_READ:
	case EXIT_REASON_MSR_WRITE:
	case EXIT_REASON_CPUID:
	case EXIT_REASON_IO_INSTRUCTION:
		/* It is safe to use ve->instr_len for #VE due to these reasons */
		return ve->instr_len;
	case EXIT_REASON_EPT_VIOLATION:
		/*
		 * For EPT violations, ve->instr_len is not defined. The
		 * kernel must decode the instruction manually (see
		 * handle_mmio()) and must not use this function.
		 */
		WARN_ONCE(1, "ve->instr_len is not defined for EPT violations");
		return 0;
	default:
		WARN_ONCE(1, "Unexpected #VE-type: %lld\n", ve->exit_reason);
		return ve->instr_len;
	}
}
0171
static u64 __cpuidle __halt(const bool irq_disabled, const bool do_sti)
{
	struct tdx_hypercall_args args = {
		.r10 = TDX_HYPERCALL_STANDARD,
		.r11 = hcall_func(EXIT_REASON_HLT),
		/* R12 tells the VMM whether interrupts are disabled */
		.r12 = irq_disabled,
	};

	/*
	 * Emulate HLT with a hypercall.
	 *
	 * do_sti requests TDX_HCALL_ISSUE_STI, i.e. that STI be executed
	 * immediately before the TDCALL (the classic "sti; hlt" pairing)
	 * so a wakeup interrupt cannot arrive in the window between
	 * enabling interrupts and entering the halt hypercall.
	 * NOTE(review): the exact STI placement is implemented in the
	 * __tdx_hypercall assembly — confirm there.
	 */
	return __tdx_hypercall(&args, do_sti ? TDX_HCALL_ISSUE_STI : 0);
}
0194
0195 static int handle_halt(struct ve_info *ve)
0196 {
0197
0198
0199
0200
0201
0202 const bool irq_disabled = irqs_disabled();
0203 const bool do_sti = false;
0204
0205 if (__halt(irq_disabled, do_sti))
0206 return -EIO;
0207
0208 return ve_instr_len(ve);
0209 }
0210
0211 void __cpuidle tdx_safe_halt(void)
0212 {
0213
0214
0215
0216
0217
0218 const bool irq_disabled = false;
0219 const bool do_sti = true;
0220
0221
0222
0223
0224 if (__halt(irq_disabled, do_sti))
0225 WARN_ONCE(1, "HLT instruction emulation failed\n");
0226 }
0227
0228 static int read_msr(struct pt_regs *regs, struct ve_info *ve)
0229 {
0230 struct tdx_hypercall_args args = {
0231 .r10 = TDX_HYPERCALL_STANDARD,
0232 .r11 = hcall_func(EXIT_REASON_MSR_READ),
0233 .r12 = regs->cx,
0234 };
0235
0236
0237
0238
0239
0240
0241 if (__tdx_hypercall(&args, TDX_HCALL_HAS_OUTPUT))
0242 return -EIO;
0243
0244 regs->ax = lower_32_bits(args.r11);
0245 regs->dx = upper_32_bits(args.r11);
0246 return ve_instr_len(ve);
0247 }
0248
0249 static int write_msr(struct pt_regs *regs, struct ve_info *ve)
0250 {
0251 struct tdx_hypercall_args args = {
0252 .r10 = TDX_HYPERCALL_STANDARD,
0253 .r11 = hcall_func(EXIT_REASON_MSR_WRITE),
0254 .r12 = regs->cx,
0255 .r13 = (u64)regs->dx << 32 | regs->ax,
0256 };
0257
0258
0259
0260
0261
0262
0263 if (__tdx_hypercall(&args, 0))
0264 return -EIO;
0265
0266 return ve_instr_len(ve);
0267 }
0268
static int handle_cpuid(struct pt_regs *regs, struct ve_info *ve)
{
	struct tdx_hypercall_args args = {
		.r10 = TDX_HYPERCALL_STANDARD,
		.r11 = hcall_func(EXIT_REASON_CPUID),
		.r12 = regs->ax,
		.r13 = regs->cx,
	};

	/*
	 * Only allow the VMM to control CPUID leaves in the range reserved
	 * for hypervisor communication (0x40000000-0x4FFFFFFF).
	 *
	 * Return all-zeros for any CPUID outside that range; this matches
	 * CPU behaviour for an unsupported leaf.
	 */
	if (regs->ax < 0x40000000 || regs->ax > 0x4FFFFFFF) {
		regs->ax = regs->bx = regs->cx = regs->dx = 0;
		return ve_instr_len(ve);
	}

	/*
	 * Emulate the CPUID instruction via a hypercall.
	 * NOTE(review): ABI presumably per the TDX Guest-Host-Communication
	 * Interface (GHCI), "VP.VMCALL<Instruction.CPUID>" — confirm.
	 */
	if (__tdx_hypercall(&args, TDX_HCALL_HAS_OUTPUT))
		return -EIO;

	/*
	 * R12-R15 contain the EAX, EBX, ECX, EDX values the hypercall
	 * produced for the CPUID; copy them back into pt_regs.
	 */
	regs->ax = args.r12;
	regs->bx = args.r13;
	regs->cx = args.r14;
	regs->dx = args.r15;

	return ve_instr_len(ve);
}
0310
0311 static bool mmio_read(int size, unsigned long addr, unsigned long *val)
0312 {
0313 struct tdx_hypercall_args args = {
0314 .r10 = TDX_HYPERCALL_STANDARD,
0315 .r11 = hcall_func(EXIT_REASON_EPT_VIOLATION),
0316 .r12 = size,
0317 .r13 = EPT_READ,
0318 .r14 = addr,
0319 .r15 = *val,
0320 };
0321
0322 if (__tdx_hypercall(&args, TDX_HCALL_HAS_OUTPUT))
0323 return false;
0324 *val = args.r11;
0325 return true;
0326 }
0327
0328 static bool mmio_write(int size, unsigned long addr, unsigned long val)
0329 {
0330 return !_tdx_hypercall(hcall_func(EXIT_REASON_EPT_VIOLATION), size,
0331 EPT_WRITE, addr, val);
0332 }
0333
/*
 * Decode and emulate the MMIO instruction that caused an EPT-violation
 * #VE. Returns the instruction length (to advance RIP) on success or a
 * negative errno on failure.
 */
static int handle_mmio(struct pt_regs *regs, struct ve_info *ve)
{
	unsigned long *reg, val, vaddr;
	char buffer[MAX_INSN_SIZE];
	struct insn insn = {};
	enum mmio_type mmio;
	int size, extend_size;
	u8 extend_val = 0;

	/* Only in-kernel MMIO is supported */
	if (WARN_ON_ONCE(user_mode(regs)))
		return -EFAULT;

	if (copy_from_kernel_nofault(buffer, (void *)regs->ip, MAX_INSN_SIZE))
		return -EFAULT;

	if (insn_decode(&insn, buffer, MAX_INSN_SIZE, INSN_MODE_64))
		return -EINVAL;

	mmio = insn_decode_mmio(&insn, &size);
	if (WARN_ON_ONCE(mmio == MMIO_DECODE_FAILED))
		return -EINVAL;

	/* WRITE_IMM and MOVS do not use a ModRM register operand */
	if (mmio != MMIO_WRITE_IMM && mmio != MMIO_MOVS) {
		reg = insn_get_modrm_reg_ptr(&insn, regs);
		if (!reg)
			return -EINVAL;
	}

	/*
	 * Reject accesses that cross a page boundary. MMIO accesses are
	 * supposed to be naturally aligned and therefore never split
	 * pages; a split access indicates either a bug or something like
	 * load_unaligned_zeropad() stepping into an MMIO page (which can
	 * recover via its exception fixup).
	 */
	vaddr = (unsigned long)insn_get_addr_ref(&insn, regs);
	if (vaddr / PAGE_SIZE != (vaddr + size - 1) / PAGE_SIZE)
		return -EFAULT;

	/* Handle writes first */
	switch (mmio) {
	case MMIO_WRITE:
		memcpy(&val, reg, size);
		if (!mmio_write(size, ve->gpa, val))
			return -EIO;
		return insn.length;
	case MMIO_WRITE_IMM:
		val = insn.immediate.value;
		if (!mmio_write(size, ve->gpa, val))
			return -EIO;
		return insn.length;
	case MMIO_READ:
	case MMIO_READ_ZERO_EXTEND:
	case MMIO_READ_SIGN_EXTEND:
		/* Reads are handled below */
		break;
	case MMIO_MOVS:
	case MMIO_DECODE_FAILED:
		/*
		 * MMIO was accessed with an instruction that could not be
		 * decoded or handled properly. It was likely not using io.h
		 * helpers or accessed MMIO accidentally.
		 */
		return -EINVAL;
	default:
		WARN_ONCE(1, "Unknown insn_decode_mmio() decode value?");
		return -EINVAL;
	}

	/* Handle reads */
	if (!mmio_read(size, ve->gpa, &val))
		return -EIO;

	switch (mmio) {
	case MMIO_READ:
		/* 32-bit ops zero-extend into the full register */
		extend_size = size == 4 ? sizeof(*reg) : 0;
		break;
	case MMIO_READ_ZERO_EXTEND:
		/* Zero extend based on operand size */
		extend_size = insn.opnd_bytes;
		break;
	case MMIO_READ_SIGN_EXTEND:
		/* Sign extend based on operand size */
		extend_size = insn.opnd_bytes;
		if (size == 1 && val & BIT(7))
			extend_val = 0xFF;
		else if (size > 1 && val & BIT(15))
			extend_val = 0xFF;
		break;
	default:
		/* All other cases must have been covered by the first switch */
		WARN_ON_ONCE(1);
		return -EINVAL;
	}

	/* Pre-fill with the extension byte, then overlay the read value */
	if (extend_size)
		memset(reg, extend_val, extend_size);
	memcpy(reg, &val, size);
	return insn.length;
}
0438
0439 static bool handle_in(struct pt_regs *regs, int size, int port)
0440 {
0441 struct tdx_hypercall_args args = {
0442 .r10 = TDX_HYPERCALL_STANDARD,
0443 .r11 = hcall_func(EXIT_REASON_IO_INSTRUCTION),
0444 .r12 = size,
0445 .r13 = PORT_READ,
0446 .r14 = port,
0447 };
0448 u64 mask = GENMASK(BITS_PER_BYTE * size, 0);
0449 bool success;
0450
0451
0452
0453
0454
0455
0456 success = !__tdx_hypercall(&args, TDX_HCALL_HAS_OUTPUT);
0457
0458
0459 regs->ax &= ~mask;
0460 if (success)
0461 regs->ax |= args.r11 & mask;
0462
0463 return success;
0464 }
0465
0466 static bool handle_out(struct pt_regs *regs, int size, int port)
0467 {
0468 u64 mask = GENMASK(BITS_PER_BYTE * size, 0);
0469
0470
0471
0472
0473
0474
0475 return !_tdx_hypercall(hcall_func(EXIT_REASON_IO_INSTRUCTION), size,
0476 PORT_WRITE, port, regs->ax & mask);
0477 }
0478
0479
0480
0481
0482
0483
0484
0485
0486
0487 static int handle_io(struct pt_regs *regs, struct ve_info *ve)
0488 {
0489 u32 exit_qual = ve->exit_qual;
0490 int size, port;
0491 bool in, ret;
0492
0493 if (VE_IS_IO_STRING(exit_qual))
0494 return -EIO;
0495
0496 in = VE_IS_IO_IN(exit_qual);
0497 size = VE_GET_IO_SIZE(exit_qual);
0498 port = VE_GET_PORT_NUM(exit_qual);
0499
0500
0501 if (in)
0502 ret = handle_in(regs, size, port);
0503 else
0504 ret = handle_out(regs, size, port);
0505 if (!ret)
0506 return -EIO;
0507
0508 return ve_instr_len(ve);
0509 }
0510
0511
0512
0513
0514
/*
 * Early #VE handler used before the regular #VE handling is available.
 * Only port I/O #VEs are handled here; every other exit reason (and any
 * I/O emulation failure) is reported as unhandled by returning false.
 */
__init bool tdx_early_handle_ve(struct pt_regs *regs)
{
	struct ve_info ve;
	int insn_len;

	tdx_get_ve_info(&ve);

	if (ve.exit_reason != EXIT_REASON_IO_INSTRUCTION)
		return false;

	insn_len = handle_io(regs, &ve);
	if (insn_len < 0)
		return false;

	/* Skip the emulated I/O instruction */
	regs->ip += insn_len;
	return true;
}
0532
/*
 * Retrieve the #VE information from the TDX module. Must be called
 * early in #VE handling.
 */
void tdx_get_ve_info(struct ve_info *ve)
{
	struct tdx_module_output out;

	/*
	 * The TDX_GET_VEINFO TDCALL fetches the #VE details.
	 * NOTE(review): per the TDX module ABI, this call also clears the
	 * "#VE valid" flag; a nested #VE taken before it completes would
	 * escalate to #DF — confirm against the TDX module spec.
	 */
	tdx_module_call(TDX_GET_VEINFO, 0, 0, 0, 0, &out);

	/* Transfer the output registers into the ve_info structure */
	ve->exit_reason = out.rcx;
	ve->exit_qual = out.rdx;
	ve->gla = out.r8;
	ve->gpa = out.r9;
	ve->instr_len = lower_32_bits(out.r10);
	ve->instr_info = upper_32_bits(out.r10);
}
0562
0563
0564
0565
0566
0567
0568
0569 static int virt_exception_user(struct pt_regs *regs, struct ve_info *ve)
0570 {
0571 switch (ve->exit_reason) {
0572 case EXIT_REASON_CPUID:
0573 return handle_cpuid(regs, ve);
0574 default:
0575 pr_warn("Unexpected #VE: %lld\n", ve->exit_reason);
0576 return -EIO;
0577 }
0578 }
0579
0580
0581
0582
0583
0584
0585
/*
 * Handle a #VE raised while in kernel mode. Dispatches on the exit
 * reason to the matching emulation helper. Returns the instruction
 * length to skip on success, or a negative errno.
 */
static int virt_exception_kernel(struct pt_regs *regs, struct ve_info *ve)
{
	switch (ve->exit_reason) {
	case EXIT_REASON_HLT:
		return handle_halt(ve);
	case EXIT_REASON_MSR_READ:
		return read_msr(regs, ve);
	case EXIT_REASON_MSR_WRITE:
		return write_msr(regs, ve);
	case EXIT_REASON_CPUID:
		return handle_cpuid(regs, ve);
	case EXIT_REASON_EPT_VIOLATION:
		return handle_mmio(regs, ve);
	case EXIT_REASON_IO_INSTRUCTION:
		return handle_io(regs, ve);
	default:
		pr_warn("Unexpected #VE: %lld\n", ve->exit_reason);
		return -EIO;
	}
}
0606
0607 bool tdx_handle_virt_exception(struct pt_regs *regs, struct ve_info *ve)
0608 {
0609 int insn_len;
0610
0611 if (user_mode(regs))
0612 insn_len = virt_exception_user(regs, ve);
0613 else
0614 insn_len = virt_exception_kernel(regs, ve);
0615 if (insn_len < 0)
0616 return false;
0617
0618
0619 regs->ip += insn_len;
0620
0621 return true;
0622 }
0623
static bool tdx_tlb_flush_required(bool private)
{
	/*
	 * Flush is required only when converting to shared mappings.
	 * NOTE(review): the rationale (per TDX conventions) is that the
	 * guest flushes on private->shared transitions while the VMM is
	 * responsible for flushing on shared->private, since the VMM
	 * cannot generate the guest's private (HKID-tagged) physical
	 * addresses — confirm against the TDX base architecture spec.
	 */
	return !private;
}
0641
static bool tdx_cache_flush_required(void)
{
	/*
	 * Always flush caches on memory conversion. Unlike some AMD
	 * SME/SEV configurations, TDX has no hardware guarantee of cache
	 * coherence across the private/shared transition
	 * (NOTE(review): confirm against the TDX architecture spec).
	 */
	return true;
}
0652
/*
 * Try to accept one page of 'pg_level' size at *start via the
 * TDX_ACCEPT_PAGE TDCALL. On success, *start is advanced past the
 * accepted page and true is returned. Returns false when *start is not
 * suitably aligned, 'len' is too short, the level is unsupported, or
 * the TDCALL fails.
 */
static bool try_accept_one(phys_addr_t *start, unsigned long len,
			   enum pg_level pg_level)
{
	unsigned long accept_size = page_level_size(pg_level);
	u64 tdcall_rcx;
	u8 page_size;

	if (!IS_ALIGNED(*start, accept_size))
		return false;

	if (len < accept_size)
		return false;

	/*
	 * The low bits of RCX encode the page size for TDX_ACCEPT_PAGE:
	 * 0 = 4K, 1 = 2M, 2 = 1G. The page-aligned physical address
	 * occupies the upper bits.
	 */
	switch (pg_level) {
	case PG_LEVEL_4K:
		page_size = 0;
		break;
	case PG_LEVEL_2M:
		page_size = 1;
		break;
	case PG_LEVEL_1G:
		page_size = 2;
		break;
	default:
		return false;
	}

	tdcall_rcx = *start | page_size;
	if (__tdx_module_call(TDX_ACCEPT_PAGE, tdcall_rcx, 0, 0, 0, NULL))
		return false;

	*start += accept_size;
	return true;
}
0693
0694
0695
0696
0697
0698
/*
 * Inform the VMM of the guest's intent for the [vaddr, vaddr+numpages)
 * range: shared with the VMM (enc == false) or private to the guest
 * (enc == true). For private conversion the pages are also accepted
 * via TDX_ACCEPT_PAGE. Returns true on success.
 */
static bool tdx_enc_status_changed(unsigned long vaddr, int numpages, bool enc)
{
	phys_addr_t start = __pa(vaddr);
	phys_addr_t end = __pa(vaddr + numpages * PAGE_SIZE);

	if (!enc) {
		/* Set the shared (decrypted) bit on the GPA range: */
		start |= cc_mkdec(0);
		end |= cc_mkdec(0);
	}

	/*
	 * Notify the VMM about the page mapping conversion via the
	 * MapGPA hypercall (TDVMCALL_MAP_GPA).
	 */
	if (_tdx_hypercall(TDVMCALL_MAP_GPA, start, end - start, 0, 0))
		return false;

	/* private->shared conversion requires only the MapGPA call */
	if (!enc)
		return true;

	/*
	 * For shared->private conversion, the pages must additionally be
	 * accepted by the guest before they can be used.
	 */
	while (start < end) {
		unsigned long len = end - start;

		/*
		 * Try the larger accept sizes first: each successful large
		 * accept covers more memory per TDCALL, reducing the number
		 * of calls needed.
		 */
		if (try_accept_one(&start, len, PG_LEVEL_1G))
			continue;

		if (try_accept_one(&start, len, PG_LEVEL_2M))
			continue;

		if (!try_accept_one(&start, len, PG_LEVEL_4K))
			return false;
	}

	return true;
}
0747
/*
 * Detect whether we are running as a TDX guest and, if so, set up the
 * confidential-computing mask and the memory-conversion callbacks.
 */
void __init tdx_early_init(void)
{
	u64 cc_mask;
	u32 eax, sig[3];

	cpuid_count(TDX_CPUID_LEAF_ID, 0, &eax, &sig[0], &sig[2], &sig[1]);

	/* Not a TDX guest if the TDX_IDENT signature is absent */
	if (memcmp(TDX_IDENT, sig, sizeof(sig)))
		return;

	setup_force_cpu_cap(X86_FEATURE_TDX_GUEST);

	cc_set_vendor(CC_VENDOR_INTEL);
	cc_mask = get_cc_mask();
	cc_set_mask(cc_mask);

	/*
	 * The kernel treats the shared bit as a flag, not as part of the
	 * physical address: restrict physical_mask to the bits below it
	 * (cc_mask is a single bit, so cc_mask - 1 is the mask of all
	 * lower address bits).
	 */
	physical_mask &= cc_mask - 1;

	/* Hook the TDX-specific memory-conversion callbacks */
	x86_platform.guest.enc_cache_flush_required = tdx_cache_flush_required;
	x86_platform.guest.enc_tlb_flush_required = tdx_tlb_flush_required;
	x86_platform.guest.enc_status_change_finish = tdx_enc_status_changed;

	pr_info("Guest detected\n");
}