Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * aes-ce-glue.c - wrapper code for ARMv8 AES
0004  *
0005  * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
0006  */
0007 
0008 #include <asm/hwcap.h>
0009 #include <asm/neon.h>
0010 #include <asm/simd.h>
0011 #include <asm/unaligned.h>
0012 #include <crypto/aes.h>
0013 #include <crypto/ctr.h>
0014 #include <crypto/internal/simd.h>
0015 #include <crypto/internal/skcipher.h>
0016 #include <crypto/scatterwalk.h>
0017 #include <linux/cpufeature.h>
0018 #include <linux/module.h>
0019 #include <crypto/xts.h>
0020 
/*
 * Module metadata.  Note: the description does not mention the
 * CTS-CBC mode, although this file also registers "cts(cbc(aes))".
 */
0021 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
0022 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
0023 MODULE_LICENSE("GPL v2");
0024 
0025 /* defined in aes-ce-core.S */
/*
 * Key schedule helpers: ce_aes_sub() is applied to single 32-bit words
 * while expanding the encryption key, and ce_aes_invert() converts one
 * 16-byte round key for the decryption schedule (both are used only by
 * ce_aes_expandkey() in this file).
 */
0026 asmlinkage u32 ce_aes_sub(u32 input);
0027 asmlinkage void ce_aes_invert(void *dst, void *src);
0028 
/*
 * Bulk mode helpers.  @rk / @rk1 is an expanded key schedule and @rounds
 * comes from num_rounds().  The ECB, CBC and CTR helpers take a count of
 * whole blocks; the CTS and XTS helpers take a byte count.  Every call
 * site in this file brackets them with kernel_neon_begin()/_end().
 */
0029 asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
0030                    int rounds, int blocks);
0031 asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
0032                    int rounds, int blocks);
0033 
0034 asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
0035                    int rounds, int blocks, u8 iv[]);
0036 asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
0037                    int rounds, int blocks, u8 iv[]);
0038 asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
0039                    int rounds, int bytes, u8 const iv[]);
0040 asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
0041                    int rounds, int bytes, u8 const iv[]);
0042 
/*
 * ctr_encrypt() calls this with blocks == -1 and in == NULL to obtain a
 * single block of keystream for a partial final block.
 */
0043 asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
0044                    int rounds, int blocks, u8 ctr[]);
0045 
/*
 * XTS helpers: callers pass key1's schedule as @rk1 and key2's encryption
 * schedule as @rk2 (presumably the tweak key -- confirm against the .S).
 * @first is 1 on the first call made for a request and 0 afterwards.
 */
0046 asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
0047                    int rounds, int bytes, u8 iv[],
0048                    u32 const rk2[], int first);
0049 asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
0050                    int rounds, int bytes, u8 iv[],
0051                    u32 const rk2[], int first);
0052 
/*
 * One AES block (16 bytes).  ce_aes_expandkey() casts the u32 key schedule
 * arrays to this type so whole round keys can be indexed and copied by
 * structure assignment.
 */
0053 struct aes_block {
0054     u8 b[AES_BLOCK_SIZE];
0055 };
0056 
0057 static int num_rounds(struct crypto_aes_ctx *ctx)
0058 {
0059     /*
0060      * # of rounds specified by AES:
0061      * 128 bit key      10 rounds
0062      * 192 bit key      12 rounds
0063      * 256 bit key      14 rounds
0064      * => n byte key    => 6 + (n/4) rounds
0065      */
0066     return 6 + ctx->key_length / 4;
0067 }
0068 
/*
 * ce_aes_expandkey - build encryption and decryption key schedules
 * @ctx:     context receiving key_length, key_enc[] and key_dec[]
 * @in_key:  raw key bytes (read with get_unaligned_le32, so no alignment
 *           requirement)
 * @key_len: 16, 24 or 32
 *
 * Returns 0 on success or -EINVAL for any other key length.  Both the
 * expansion loop and the decryption-key loop call assembly helpers, so the
 * whole computation runs between kernel_neon_begin() and kernel_neon_end().
 * On success num_rounds(ctx) + 1 round keys are written to each schedule.
 */
0069 static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
0070                 unsigned int key_len)
0071 {
0072     /*
0073      * The AES key schedule round constants
0074      */
0075     static u8 const rcon[] = {
0076         0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
0077     };
0078 
         /* Key length in 32-bit words: 4, 6 or 8. */
0079     u32 kwords = key_len / sizeof(u32);
0080     struct aes_block *key_enc, *key_dec;
0081     int i, j;
0082 
0083     if (key_len != AES_KEYSIZE_128 &&
0084         key_len != AES_KEYSIZE_192 &&
0085         key_len != AES_KEYSIZE_256)
0086         return -EINVAL;
0087 
         /* The first kwords schedule words are the key itself. */
0088     ctx->key_length = key_len;
0089     for (i = 0; i < kwords; i++)
0090         ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
0091 
0092     kernel_neon_begin();
         /*
          * Each iteration derives the next kwords words (rko) from the
          * previous kwords words (rki).  192- and 256-bit keys stop early
          * after writing only the first four words of the last group, so
          * every key size ends with exactly rounds + 1 round keys
          * (44, 52 or 60 words).
          */
0093     for (i = 0; i < sizeof(rcon); i++) {
0094         u32 *rki = ctx->key_enc + (i * kwords);
0095         u32 *rko = rki + kwords;
0096 
             /*
              * SubWord of the last word, then ror32 by 8, which rotates
              * the bytes of a little-endian word by one position (RotWord;
              * the two steps commute).  The round constant lands in the
              * first key byte.
              */
0097         rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
0098         rko[0] = rko[0] ^ rki[0] ^ rcon[i];
0099         rko[1] = rko[0] ^ rki[1];
0100         rko[2] = rko[1] ^ rki[2];
0101         rko[3] = rko[2] ^ rki[3];
0102 
0103         if (key_len == AES_KEYSIZE_192) {
0104             if (i >= 7)
0105                 break;
0106             rko[4] = rko[3] ^ rki[4];
0107             rko[5] = rko[4] ^ rki[5];
0108         } else if (key_len == AES_KEYSIZE_256) {
0109             if (i >= 6)
0110                 break;
                 /* 256-bit keys apply an extra SubWord mid-group. */
0111             rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
0112             rko[5] = rko[4] ^ rki[5];
0113             rko[6] = rko[5] ^ rki[6];
0114             rko[7] = rko[6] ^ rki[7];
0115         }
0116     }
0117 
0118     /*
0119      * Generate the decryption keys for the Equivalent Inverse Cipher.
0120      * This involves reversing the order of the round keys, and applying
0121      * the Inverse Mix Columns transformation on all but the first and
0122      * the last one.
0123      */
0124     key_enc = (struct aes_block *)ctx->key_enc;
0125     key_dec = (struct aes_block *)ctx->key_dec;
0126     j = num_rounds(ctx);
0127 
         /* First and last round keys are copied unchanged, swapped. */
0128     key_dec[0] = key_enc[j];
0129     for (i = 1, j--; j > 0; i++, j--)
0130         ce_aes_invert(key_dec + i, key_enc + j);
0131     key_dec[i] = key_enc[0];
0132 
0133     kernel_neon_end();
0134     return 0;
0135 }
0136 
0137 static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0138              unsigned int key_len)
0139 {
0140     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0141 
0142     return ce_aes_expandkey(ctx, in_key, key_len);
0143 }
0144 
/*
 * Per-transform context for XTS: two independently expanded AES keys.
 * xts_encrypt()/xts_decrypt() use key1 (key_enc or key_dec) for the data
 * and key2.key_enc as the helper's second schedule (rk2).
 * NOTE(review): the reason for __aligned(8) on key2 is not visible here --
 * presumably an alignment requirement of the assembly helpers; confirm.
 */
0145 struct crypto_aes_xts_ctx {
0146     struct crypto_aes_ctx key1;
0147     struct crypto_aes_ctx __aligned(8) key2;
0148 };
0149 
0150 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
0151                unsigned int key_len)
0152 {
0153     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0154     int ret;
0155 
0156     ret = xts_verify_key(tfm, in_key, key_len);
0157     if (ret)
0158         return ret;
0159 
0160     ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
0161     if (!ret)
0162         ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
0163                        key_len / 2);
0164     return ret;
0165 }
0166 
0167 static int ecb_encrypt(struct skcipher_request *req)
0168 {
0169     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0170     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0171     struct skcipher_walk walk;
0172     unsigned int blocks;
0173     int err;
0174 
0175     err = skcipher_walk_virt(&walk, req, false);
0176 
0177     while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
0178         kernel_neon_begin();
0179         ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0180                    ctx->key_enc, num_rounds(ctx), blocks);
0181         kernel_neon_end();
0182         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0183     }
0184     return err;
0185 }
0186 
0187 static int ecb_decrypt(struct skcipher_request *req)
0188 {
0189     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0190     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0191     struct skcipher_walk walk;
0192     unsigned int blocks;
0193     int err;
0194 
0195     err = skcipher_walk_virt(&walk, req, false);
0196 
0197     while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
0198         kernel_neon_begin();
0199         ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0200                    ctx->key_dec, num_rounds(ctx), blocks);
0201         kernel_neon_end();
0202         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0203     }
0204     return err;
0205 }
0206 
0207 static int cbc_encrypt_walk(struct skcipher_request *req,
0208                 struct skcipher_walk *walk)
0209 {
0210     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0211     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0212     unsigned int blocks;
0213     int err = 0;
0214 
0215     while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
0216         kernel_neon_begin();
0217         ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
0218                    ctx->key_enc, num_rounds(ctx), blocks,
0219                    walk->iv);
0220         kernel_neon_end();
0221         err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
0222     }
0223     return err;
0224 }
0225 
0226 static int cbc_encrypt(struct skcipher_request *req)
0227 {
0228     struct skcipher_walk walk;
0229     int err;
0230 
0231     err = skcipher_walk_virt(&walk, req, false);
0232     if (err)
0233         return err;
0234     return cbc_encrypt_walk(req, &walk);
0235 }
0236 
0237 static int cbc_decrypt_walk(struct skcipher_request *req,
0238                 struct skcipher_walk *walk)
0239 {
0240     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0241     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0242     unsigned int blocks;
0243     int err = 0;
0244 
0245     while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
0246         kernel_neon_begin();
0247         ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
0248                    ctx->key_dec, num_rounds(ctx), blocks,
0249                    walk->iv);
0250         kernel_neon_end();
0251         err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
0252     }
0253     return err;
0254 }
0255 
0256 static int cbc_decrypt(struct skcipher_request *req)
0257 {
0258     struct skcipher_walk walk;
0259     int err;
0260 
0261     err = skcipher_walk_virt(&walk, req, false);
0262     if (err)
0263         return err;
0264     return cbc_decrypt_walk(req, &walk);
0265 }
0266 
/*
 * cts_cbc_encrypt - CBC encryption with ciphertext stealing
 *
 * Requests shorter than one block fail with -EINVAL.  A request of exactly
 * one block is plain CBC.  Otherwise all but the last two (the last one
 * possibly partial) blocks are encrypted as plain CBC via cbc_encrypt_walk(),
 * and the remaining 16 < n <= 32 bytes are passed to ce_aes_cbc_cts_encrypt()
 * in a single call.  For 16 < cryptlen <= 32, cbc_blocks computes to 0 and
 * the whole request goes straight to the CTS step.
 */
0267 static int cts_cbc_encrypt(struct skcipher_request *req)
0268 {
0269     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0270     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0271     int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
0272     struct scatterlist *src = req->src, *dst = req->dst;
0273     struct scatterlist sg_src[2], sg_dst[2];
0274     struct skcipher_request subreq;
0275     struct skcipher_walk walk;
0276     int err;
0277 
         /* Sub-request on the same tfm, inheriting the caller's flags. */
0278     skcipher_request_set_tfm(&subreq, tfm);
0279     skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
0280                       NULL, NULL);
0281 
0282     if (req->cryptlen <= AES_BLOCK_SIZE) {
0283         if (req->cryptlen < AES_BLOCK_SIZE)
0284             return -EINVAL;
0285         cbc_blocks = 1;
0286     }
0287 
0288     if (cbc_blocks > 0) {
             /*
              * Plain CBC over the leading blocks.  req->iv is shared with
              * the sub-request, presumably so the chaining value carries
              * over into the CTS step below -- confirm against
              * skcipher_walk IV write-back.
              */
0289         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0290                        cbc_blocks * AES_BLOCK_SIZE,
0291                        req->iv);
0292 
0293         err = skcipher_walk_virt(&walk, &subreq, false) ?:
0294               cbc_encrypt_walk(&subreq, &walk);
0295         if (err)
0296             return err;
0297 
0298         if (req->cryptlen == AES_BLOCK_SIZE)
0299             return 0;
0300 
             /* Point src/dst at the bytes following the CBC part. */
0301         dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
0302         if (req->dst != req->src)
0303             dst = scatterwalk_ffwd(sg_dst, req->dst,
0304                            subreq.cryptlen);
0305     }
0306 
0307     /* handle ciphertext stealing */
0308     skcipher_request_set_crypt(&subreq, src, dst,
0309                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
0310                    req->iv);
0311 
0312     err = skcipher_walk_virt(&walk, &subreq, false);
0313     if (err)
0314         return err;
0315 
         /*
          * The tail is processed in one call and the walk is completed with
          * 0 bytes left, so the whole (<= 32 byte) tail is assumed to be
          * mapped in this single step.  NOTE(review): this relies on
          * .walksize = 2 * AES_BLOCK_SIZE in the alg definition -- confirm.
          */
0316     kernel_neon_begin();
0317     ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0318                    ctx->key_enc, num_rounds(ctx), walk.nbytes,
0319                    walk.iv);
0320     kernel_neon_end();
0321 
0322     return skcipher_walk_done(&walk, 0);
0323 }
0324 
/*
 * cts_cbc_decrypt - CBC decryption with ciphertext stealing
 *
 * Mirror of cts_cbc_encrypt() using the decryption schedule.  Requests
 * shorter than one block fail with -EINVAL.  Exactly one block is plain
 * CBC.  Otherwise the leading blocks go through cbc_decrypt_walk(), and the
 * final 16 < n <= 32 bytes go to ce_aes_cbc_cts_decrypt() in one call.
 */
0325 static int cts_cbc_decrypt(struct skcipher_request *req)
0326 {
0327     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0328     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0329     int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
0330     struct scatterlist *src = req->src, *dst = req->dst;
0331     struct scatterlist sg_src[2], sg_dst[2];
0332     struct skcipher_request subreq;
0333     struct skcipher_walk walk;
0334     int err;
0335 
         /* Sub-request on the same tfm, inheriting the caller's flags. */
0336     skcipher_request_set_tfm(&subreq, tfm);
0337     skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
0338                       NULL, NULL);
0339 
0340     if (req->cryptlen <= AES_BLOCK_SIZE) {
0341         if (req->cryptlen < AES_BLOCK_SIZE)
0342             return -EINVAL;
0343         cbc_blocks = 1;
0344     }
0345 
0346     if (cbc_blocks > 0) {
             /* Plain CBC over the leading blocks, sharing req->iv. */
0347         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0348                        cbc_blocks * AES_BLOCK_SIZE,
0349                        req->iv);
0350 
0351         err = skcipher_walk_virt(&walk, &subreq, false) ?:
0352               cbc_decrypt_walk(&subreq, &walk);
0353         if (err)
0354             return err;
0355 
0356         if (req->cryptlen == AES_BLOCK_SIZE)
0357             return 0;
0358 
             /* Point src/dst at the bytes following the CBC part. */
0359         dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
0360         if (req->dst != req->src)
0361             dst = scatterwalk_ffwd(sg_dst, req->dst,
0362                            subreq.cryptlen);
0363     }
0364 
0365     /* handle ciphertext stealing */
0366     skcipher_request_set_crypt(&subreq, src, dst,
0367                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
0368                    req->iv);
0369 
0370     err = skcipher_walk_virt(&walk, &subreq, false);
0371     if (err)
0372         return err;
0373 
         /*
          * Whole tail in one call; assumes the walk maps it in a single
          * step.  NOTE(review): relies on .walksize = 2 * AES_BLOCK_SIZE
          * in the alg definition -- confirm.
          */
0374     kernel_neon_begin();
0375     ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0376                    ctx->key_dec, num_rounds(ctx), walk.nbytes,
0377                    walk.iv);
0378     kernel_neon_end();
0379 
0380     return skcipher_walk_done(&walk, 0);
0381 }
0382 
0383 static int ctr_encrypt(struct skcipher_request *req)
0384 {
0385     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0386     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0387     struct skcipher_walk walk;
0388     int err, blocks;
0389 
0390     err = skcipher_walk_virt(&walk, req, false);
0391 
0392     while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
0393         kernel_neon_begin();
0394         ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0395                    ctx->key_enc, num_rounds(ctx), blocks,
0396                    walk.iv);
0397         kernel_neon_end();
0398         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0399     }
0400     if (walk.nbytes) {
0401         u8 __aligned(8) tail[AES_BLOCK_SIZE];
0402         unsigned int nbytes = walk.nbytes;
0403         u8 *tdst = walk.dst.virt.addr;
0404         u8 *tsrc = walk.src.virt.addr;
0405 
0406         /*
0407          * Tell aes_ctr_encrypt() to process a tail block.
0408          */
0409         blocks = -1;
0410 
0411         kernel_neon_begin();
0412         ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
0413                    blocks, walk.iv);
0414         kernel_neon_end();
0415         crypto_xor_cpy(tdst, tsrc, tail, nbytes);
0416         err = skcipher_walk_done(&walk, 0);
0417     }
0418     return err;
0419 }
0420 
0421 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
0422 {
0423     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0424     unsigned long flags;
0425 
0426     /*
0427      * Temporarily disable interrupts to avoid races where
0428      * cachelines are evicted when the CPU is interrupted
0429      * to do something else.
0430      */
0431     local_irq_save(flags);
0432     aes_encrypt(ctx, dst, src);
0433     local_irq_restore(flags);
0434 }
0435 
0436 static int ctr_encrypt_sync(struct skcipher_request *req)
0437 {
0438     if (!crypto_simd_usable())
0439         return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
0440 
0441     return ctr_encrypt(req);
0442 }
0443 
/*
 * xts_encrypt - AES-XTS encryption with ciphertext stealing
 *
 * Data is encrypted with key1.key_enc; key2.key_enc is passed as the
 * helper's second schedule.  Requests shorter than one block fail with
 * -EINVAL.
 *
 * Fast path: no partial final block, or the first walk step already covers
 * the whole request.  Everything is then handled by the main loop, whose
 * final step passes the unrounded byte count.
 *
 * Slow path: the length is not block aligned and the data spans several
 * walk steps.  All but the last two (one partial) blocks go through the
 * main loop on a sub-request, and the remaining AES_BLOCK_SIZE + tail
 * bytes are then processed in a single call.
 */
0444 static int xts_encrypt(struct skcipher_request *req)
0445 {
0446     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0447     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0448     int err, first, rounds = num_rounds(&ctx->key1);
0449     int tail = req->cryptlen % AES_BLOCK_SIZE;
0450     struct scatterlist sg_src[2], sg_dst[2];
0451     struct skcipher_request subreq;
0452     struct scatterlist *src, *dst;
0453     struct skcipher_walk walk;
0454 
0455     if (req->cryptlen < AES_BLOCK_SIZE)
0456         return -EINVAL;
0457 
0458     err = skcipher_walk_virt(&walk, req, false);
0459 
0460     if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
             /*
              * Slow path: restart the walk on a sub-request covering all
              * but the last two blocks (xts_blocks may be 0 when
              * cryptlen < 2 * AES_BLOCK_SIZE).
              */
0461         int xts_blocks = DIV_ROUND_UP(req->cryptlen,
0462                           AES_BLOCK_SIZE) - 2;
0463 
0464         skcipher_walk_abort(&walk);
0465 
0466         skcipher_request_set_tfm(&subreq, tfm);
0467         skcipher_request_set_callback(&subreq,
0468                           skcipher_request_flags(req),
0469                           NULL, NULL);
0470         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0471                        xts_blocks * AES_BLOCK_SIZE,
0472                        req->iv);
0473         req = &subreq;
0474         err = skcipher_walk_virt(&walk, req, false);
0475     } else {
             /* Fast path: the main loop handles any partial block itself. */
0476         tail = 0;
0477     }
0478 
         /*
          * Intermediate steps are rounded down to whole blocks and the rest
          * is handed back to the walker; the final step (nbytes == total)
          * is passed through as-is.  first is 1 only for the first call.
          */
0479     for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
0480         int nbytes = walk.nbytes;
0481 
0482         if (walk.nbytes < walk.total)
0483             nbytes &= ~(AES_BLOCK_SIZE - 1);
0484 
0485         kernel_neon_begin();
0486         ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0487                    ctx->key1.key_enc, rounds, nbytes, walk.iv,
0488                    ctx->key2.key_enc, first);
0489         kernel_neon_end();
0490         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0491     }
0492 
0493     if (err || likely(!tail))
0494         return err;
0495 
         /*
          * Slow path only: req is the sub-request here, so req->cryptlen
          * is the number of bytes already processed.
          */
0496     dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
0497     if (req->dst != req->src)
0498         dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
0499 
0500     skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
0501                    req->iv);
0502 
0503     err = skcipher_walk_virt(&walk, req, false);
0504     if (err)
0505         return err;
0506 
         /*
          * One call for the final AES_BLOCK_SIZE + tail bytes; first is still
          * 1 if the main loop never ran.
          */
0507     kernel_neon_begin();
0508     ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0509                ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
0510                ctx->key2.key_enc, first);
0511     kernel_neon_end();
0512 
0513     return skcipher_walk_done(&walk, 0);
0514 }
0515 
/*
 * xts_decrypt - AES-XTS decryption with ciphertext stealing
 *
 * Mirror of xts_encrypt(): the data is decrypted with key1.key_dec, while
 * key2.key_enc (the encryption schedule) is still the helper's second
 * schedule.  Requests shorter than one block fail with -EINVAL.  When the
 * length is not block aligned and the data spans several walk steps, the
 * final AES_BLOCK_SIZE + tail bytes are processed in a separate single call.
 */
0516 static int xts_decrypt(struct skcipher_request *req)
0517 {
0518     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0519     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0520     int err, first, rounds = num_rounds(&ctx->key1);
0521     int tail = req->cryptlen % AES_BLOCK_SIZE;
0522     struct scatterlist sg_src[2], sg_dst[2];
0523     struct skcipher_request subreq;
0524     struct scatterlist *src, *dst;
0525     struct skcipher_walk walk;
0526 
0527     if (req->cryptlen < AES_BLOCK_SIZE)
0528         return -EINVAL;
0529 
0530     err = skcipher_walk_virt(&walk, req, false);
0531 
0532     if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
             /*
              * Slow path: restart on a sub-request covering all but the
              * last two blocks (possibly zero bytes).
              */
0533         int xts_blocks = DIV_ROUND_UP(req->cryptlen,
0534                           AES_BLOCK_SIZE) - 2;
0535 
0536         skcipher_walk_abort(&walk);
0537 
0538         skcipher_request_set_tfm(&subreq, tfm);
0539         skcipher_request_set_callback(&subreq,
0540                           skcipher_request_flags(req),
0541                           NULL, NULL);
0542         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0543                        xts_blocks * AES_BLOCK_SIZE,
0544                        req->iv);
0545         req = &subreq;
0546         err = skcipher_walk_virt(&walk, req, false);
0547     } else {
             /* Fast path: the main loop handles any partial block itself. */
0548         tail = 0;
0549     }
0550 
         /*
          * Intermediate steps are rounded down to whole blocks; the final
          * step is passed through unrounded.  first is 1 only initially.
          */
0551     for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
0552         int nbytes = walk.nbytes;
0553 
0554         if (walk.nbytes < walk.total)
0555             nbytes &= ~(AES_BLOCK_SIZE - 1);
0556 
0557         kernel_neon_begin();
0558         ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0559                    ctx->key1.key_dec, rounds, nbytes, walk.iv,
0560                    ctx->key2.key_enc, first);
0561         kernel_neon_end();
0562         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0563     }
0564 
0565     if (err || likely(!tail))
0566         return err;
0567 
         /* Slow path: skip the bytes the sub-request already processed. */
0568     dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
0569     if (req->dst != req->src)
0570         dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
0571 
0572     skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
0573                    req->iv);
0574 
0575     err = skcipher_walk_virt(&walk, req, false);
0576     if (err)
0577         return err;
0578 
         /* Final AES_BLOCK_SIZE + tail bytes in one call. */
0579     kernel_neon_begin();
0580     ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0581                ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
0582                ctx->key2.key_enc, first);
0583     kernel_neon_end();
0584 
0585     return skcipher_walk_done(&walk, 0);
0586 }
0587 
/*
 * Algorithms registered by aes_init().  Entries named "__..." carry
 * CRYPTO_ALG_INTERNAL; aes_init() wraps each of them with
 * simd_skcipher_create_compat() under the same names minus the "__"
 * prefix.  The one non-internal entry, "ctr-aes-ce-sync", is exposed
 * directly at priority 300 - 1 and never gets a wrapper.
 * aes_simd_algs[] is indexed in parallel with this array.
 */
0588 static struct skcipher_alg aes_algs[] = { {
0589     .base.cra_name      = "__ecb(aes)",
0590     .base.cra_driver_name   = "__ecb-aes-ce",
0591     .base.cra_priority  = 300,
0592     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0593     .base.cra_blocksize = AES_BLOCK_SIZE,
0594     .base.cra_ctxsize   = sizeof(struct crypto_aes_ctx),
0595     .base.cra_module    = THIS_MODULE,
0596 
0597     .min_keysize        = AES_MIN_KEY_SIZE,
0598     .max_keysize        = AES_MAX_KEY_SIZE,
0599     .setkey         = ce_aes_setkey,
0600     .encrypt        = ecb_encrypt,
0601     .decrypt        = ecb_decrypt,
0602 }, {
0603     .base.cra_name      = "__cbc(aes)",
0604     .base.cra_driver_name   = "__cbc-aes-ce",
0605     .base.cra_priority  = 300,
0606     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0607     .base.cra_blocksize = AES_BLOCK_SIZE,
0608     .base.cra_ctxsize   = sizeof(struct crypto_aes_ctx),
0609     .base.cra_module    = THIS_MODULE,
0610 
0611     .min_keysize        = AES_MIN_KEY_SIZE,
0612     .max_keysize        = AES_MAX_KEY_SIZE,
0613     .ivsize         = AES_BLOCK_SIZE,
0614     .setkey         = ce_aes_setkey,
0615     .encrypt        = cbc_encrypt,
0616     .decrypt        = cbc_decrypt,
0617 }, {
    /*
     * CBC with ciphertext stealing.  The two-block walksize appears to be
     * what lets cts_cbc_*() process the final <= 32 bytes in one walk
     * step -- confirm.
     */
0618     .base.cra_name      = "__cts(cbc(aes))",
0619     .base.cra_driver_name   = "__cts-cbc-aes-ce",
0620     .base.cra_priority  = 300,
0621     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0622     .base.cra_blocksize = AES_BLOCK_SIZE,
0623     .base.cra_ctxsize   = sizeof(struct crypto_aes_ctx),
0624     .base.cra_module    = THIS_MODULE,
0625 
0626     .min_keysize        = AES_MIN_KEY_SIZE,
0627     .max_keysize        = AES_MAX_KEY_SIZE,
0628     .ivsize         = AES_BLOCK_SIZE,
0629     .walksize       = 2 * AES_BLOCK_SIZE,
0630     .setkey         = ce_aes_setkey,
0631     .encrypt        = cts_cbc_encrypt,
0632     .decrypt        = cts_cbc_decrypt,
0633 }, {
    /* CTR: stream mode (blocksize 1); the same hook encrypts and decrypts. */
0634     .base.cra_name      = "__ctr(aes)",
0635     .base.cra_driver_name   = "__ctr-aes-ce",
0636     .base.cra_priority  = 300,
0637     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0638     .base.cra_blocksize = 1,
0639     .base.cra_ctxsize   = sizeof(struct crypto_aes_ctx),
0640     .base.cra_module    = THIS_MODULE,
0641 
0642     .min_keysize        = AES_MIN_KEY_SIZE,
0643     .max_keysize        = AES_MAX_KEY_SIZE,
0644     .ivsize         = AES_BLOCK_SIZE,
0645     .chunksize      = AES_BLOCK_SIZE,
0646     .setkey         = ce_aes_setkey,
0647     .encrypt        = ctr_encrypt,
0648     .decrypt        = ctr_encrypt,
0649 }, {
    /*
     * Synchronous CTR, registered directly (no CRYPTO_ALG_INTERNAL, so no
     * SIMD wrapper).  ctr_encrypt_sync() falls back to the generic AES
     * library when SIMD is not usable.
     */
0650     .base.cra_name      = "ctr(aes)",
0651     .base.cra_driver_name   = "ctr-aes-ce-sync",
0652     .base.cra_priority  = 300 - 1,
0653     .base.cra_blocksize = 1,
0654     .base.cra_ctxsize   = sizeof(struct crypto_aes_ctx),
0655     .base.cra_module    = THIS_MODULE,
0656 
0657     .min_keysize        = AES_MIN_KEY_SIZE,
0658     .max_keysize        = AES_MAX_KEY_SIZE,
0659     .ivsize         = AES_BLOCK_SIZE,
0660     .chunksize      = AES_BLOCK_SIZE,
0661     .setkey         = ce_aes_setkey,
0662     .encrypt        = ctr_encrypt_sync,
0663     .decrypt        = ctr_encrypt_sync,
0664 }, {
    /* XTS: the key is two AES keys, hence the doubled key size limits. */
0665     .base.cra_name      = "__xts(aes)",
0666     .base.cra_driver_name   = "__xts-aes-ce",
0667     .base.cra_priority  = 300,
0668     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0669     .base.cra_blocksize = AES_BLOCK_SIZE,
0670     .base.cra_ctxsize   = sizeof(struct crypto_aes_xts_ctx),
0671     .base.cra_module    = THIS_MODULE,
0672 
0673     .min_keysize        = 2 * AES_MIN_KEY_SIZE,
0674     .max_keysize        = 2 * AES_MAX_KEY_SIZE,
0675     .ivsize         = AES_BLOCK_SIZE,
0676     .walksize       = 2 * AES_BLOCK_SIZE,
0677     .setkey         = xts_set_key,
0678     .encrypt        = xts_encrypt,
0679     .decrypt        = xts_decrypt,
0680 } };
0681 
/*
 * SIMD wrappers created by aes_init(), indexed in parallel with aes_algs[].
 * Slots for entries without CRYPTO_ALG_INTERNAL are never assigned and stay
 * NULL, so the array may contain holes.
 */
0682 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
0683 
0684 static void aes_exit(void)
0685 {
0686     int i;
0687 
0688     for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
0689         simd_skcipher_free(aes_simd_algs[i]);
0690 
0691     crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0692 }
0693 
0694 static int __init aes_init(void)
0695 {
0696     struct simd_skcipher_alg *simd;
0697     const char *basename;
0698     const char *algname;
0699     const char *drvname;
0700     int err;
0701     int i;
0702 
0703     err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0704     if (err)
0705         return err;
0706 
0707     for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
0708         if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
0709             continue;
0710 
0711         algname = aes_algs[i].base.cra_name + 2;
0712         drvname = aes_algs[i].base.cra_driver_name + 2;
0713         basename = aes_algs[i].base.cra_driver_name;
0714         simd = simd_skcipher_create_compat(algname, drvname, basename);
0715         err = PTR_ERR(simd);
0716         if (IS_ERR(simd))
0717             goto unregister_simds;
0718 
0719         aes_simd_algs[i] = simd;
0720     }
0721 
0722     return 0;
0723 
0724 unregister_simds:
0725     aes_exit();
0726     return err;
0727 }
0728 
/*
 * aes_init() is registered through module_cpu_feature_match() keyed on the
 * AES CPU feature (presumably so the module only initialises on CPUs that
 * report the ARMv8 AES instructions -- see <linux/cpufeature.h>).
 * aes_exit() is the matching teardown.
 */
0729 module_cpu_feature_match(AES, aes_init);
0730 module_exit(aes_exit);