Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
0004  *
0005  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
0006  */
0007 
0008 #include <asm/neon.h>
0009 #include <asm/hwcap.h>
0010 #include <asm/simd.h>
0011 #include <crypto/aes.h>
0012 #include <crypto/ctr.h>
0013 #include <crypto/sha2.h>
0014 #include <crypto/internal/hash.h>
0015 #include <crypto/internal/simd.h>
0016 #include <crypto/internal/skcipher.h>
0017 #include <crypto/scatterwalk.h>
0018 #include <linux/module.h>
0019 #include <linux/cpufeature.h>
0020 #include <crypto/xts.h>
0021 
0022 #include "aes-ce-setkey.h"
0023 
/*
 * Build-time backend selection. This glue code can be built in two
 * flavours, chosen by USE_V8_CRYPTO_EXTENSIONS:
 *  - defined:   the generic aes_* names used throughout this file map to
 *               the ce_aes_* assembler entry points; drivers are suffixed
 *               "ce" and registered at priority 300;
 *  - undefined: they map to the neon_aes_* entry points; drivers are
 *               suffixed "neon" and registered at priority 200.
 *
 * aes_expandkey is only remapped in the CE flavour. In the NEON flavour
 * the name refers to whichever aes_expandkey the included headers declare
 * (presumably the generic library routine from <crypto/aes.h> - TODO
 * confirm).
 */
0024 #ifdef USE_V8_CRYPTO_EXTENSIONS
0025 #define MODE            "ce"
0026 #define PRIO            300
0027 #define aes_expandkey       ce_aes_expandkey
0028 #define aes_ecb_encrypt     ce_aes_ecb_encrypt
0029 #define aes_ecb_decrypt     ce_aes_ecb_decrypt
0030 #define aes_cbc_encrypt     ce_aes_cbc_encrypt
0031 #define aes_cbc_decrypt     ce_aes_cbc_decrypt
0032 #define aes_cbc_cts_encrypt ce_aes_cbc_cts_encrypt
0033 #define aes_cbc_cts_decrypt ce_aes_cbc_cts_decrypt
0034 #define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
0035 #define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
0036 #define aes_ctr_encrypt     ce_aes_ctr_encrypt
0037 #define aes_xctr_encrypt    ce_aes_xctr_encrypt
0038 #define aes_xts_encrypt     ce_aes_xts_encrypt
0039 #define aes_xts_decrypt     ce_aes_xts_decrypt
0040 #define aes_mac_update      ce_aes_mac_update
0041 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
0042 #else
0043 #define MODE            "neon"
0044 #define PRIO            200
0045 #define aes_ecb_encrypt     neon_aes_ecb_encrypt
0046 #define aes_ecb_decrypt     neon_aes_ecb_decrypt
0047 #define aes_cbc_encrypt     neon_aes_cbc_encrypt
0048 #define aes_cbc_decrypt     neon_aes_cbc_decrypt
0049 #define aes_cbc_cts_encrypt neon_aes_cbc_cts_encrypt
0050 #define aes_cbc_cts_decrypt neon_aes_cbc_cts_decrypt
0051 #define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
0052 #define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
0053 #define aes_ctr_encrypt     neon_aes_ctr_encrypt
0054 #define aes_xctr_encrypt    neon_aes_xctr_encrypt
0055 #define aes_xts_encrypt     neon_aes_xts_encrypt
0056 #define aes_xts_decrypt     neon_aes_xts_decrypt
0057 #define aes_mac_update      neon_aes_mac_update
0058 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
0059 #endif
/*
 * The plain ecb/cbc/ctr/xts/xctr aliases are advertised only when this
 * build also registers those algorithms. This is the same condition that
 * guards the first entries of aes_algs[] further down. The CE flavour
 * always registers them; the NEON flavour does so only when
 * CONFIG_CRYPTO_AES_ARM64_BS is not enabled.
 */
0060 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
0061 MODULE_ALIAS_CRYPTO("ecb(aes)");
0062 MODULE_ALIAS_CRYPTO("cbc(aes)");
0063 MODULE_ALIAS_CRYPTO("ctr(aes)");
0064 MODULE_ALIAS_CRYPTO("xts(aes)");
0065 MODULE_ALIAS_CRYPTO("xctr(aes)");
0066 #endif
/* Always-registered modes and MAC algorithms. */
0067 MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
0068 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
0069 MODULE_ALIAS_CRYPTO("cmac(aes)");
0070 MODULE_ALIAS_CRYPTO("xcbc(aes)");
0071 MODULE_ALIAS_CRYPTO("cbcmac(aes)");
0072 
0073 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
0074 MODULE_LICENSE("GPL v2");
0075 
0076 /* defined in aes-modes.S */
/*
 * Common conventions, as used by the callers in this file:
 *  - rk/rk1 is an expanded key schedule (key_enc or key_dec);
 *  - rounds is computed by callers as 6 + key_length / 4, i.e. 10/12/14
 *    for 16/24/32-byte keys;
 *  - every call is bracketed by kernel_neon_begin()/kernel_neon_end().
 */

/* ECB: @blocks is a count of whole AES blocks. */
0077 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
0078                 int rounds, int blocks);
0079 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
0080                 int rounds, int blocks);
0081 
/*
 * CBC: @blocks whole blocks. @iv is writable; callers pass walk->iv on
 * every walk step of one request.
 */
0082 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
0083                 int rounds, int blocks, u8 iv[]);
0084 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
0085                 int rounds, int blocks, u8 iv[]);
0086 
/*
 * CBC ciphertext stealing over the final tail of a cts(cbc(aes))
 * request. @bytes is a byte count (the cts_cbc_* callers pass the whole
 * 17..32-byte tail in one call).
 */
0087 asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
0088                 int rounds, int bytes, u8 const iv[]);
0089 asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
0090                 int rounds, int bytes, u8 const iv[]);
0091 
/*
 * CTR: @bytes may end in a partial block only on the last call of a
 * request. @ctr is writable (the caller passes walk.iv).
 */
0092 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
0093                 int rounds, int bytes, u8 ctr[]);
0094 
/*
 * XCTR: like CTR, but the caller also passes @byte_ctr, the number of
 * bytes of this request already processed by earlier calls.
 */
0095 asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
0096                  int rounds, int bytes, u8 ctr[], int byte_ctr);
0097 
/*
 * XTS: @rk1 is key1's schedule and @rk2 is key2's encryption schedule.
 * @bytes is a byte count. @first is 1 on the first call for a request and
 * 0 on later calls.
 */
0098 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
0099                 int rounds, int bytes, u32 const rk2[], u8 iv[],
0100                 int first);
0101 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
0102                 int rounds, int bytes, u32 const rk2[], u8 iv[],
0103                 int first);
0104 
/*
 * ESSIV-CBC: @blocks whole blocks, keyed with @rk1 (key1). @rk2 is the
 * encryption schedule of the SHA-256-derived key2; callers pass req->iv
 * as @iv.
 */
0105 asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
0106                       int rounds, int blocks, u8 iv[],
0107                       u32 const rk2[]);
0108 asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
0109                       int rounds, int blocks, u8 iv[],
0110                       u32 const rk2[]);
0111 
/*
 * MAC update over @blocks blocks into the running digest @dg. The return
 * value is treated by mac_do_update() as the number of blocks still left
 * to process, and it keeps calling until that reaches 0. @enc_before and
 * @enc_after are flags; mac_do_update() clears enc_before after the first
 * call.
 */
0112 asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
0113                   int blocks, u8 dg[], int enc_before,
0114                   int enc_after);
0115 
/*
 * Per-tfm context for xts(aes): two independent AES key schedules filled
 * by xts_set_key() from the two halves of the supplied key.
 */
0116 struct crypto_aes_xts_ctx {
    /* first half of the key; passed as rk1 (data key) */
0117     struct crypto_aes_ctx key1;
    /* second half; its key_enc is passed as rk2. The 8-byte alignment
     * is presumably required by the assembler - TODO confirm. */
0118     struct crypto_aes_ctx __aligned(8) key2;
0119 };
0120 
/* Per-tfm context for essiv(cbc(aes),sha256). */
0121 struct crypto_aes_essiv_cbc_ctx {
    /* the caller's AES key */
0122     struct crypto_aes_ctx key1;
    /* AES-256 key derived as SHA-256(caller key) in essiv_cbc_set_key() */
0123     struct crypto_aes_ctx __aligned(8) key2;
    /* "sha256" shash allocated in essiv_cbc_init_tfm(), freed in exit */
0124     struct crypto_shash *hash;
0125 };
0126 
/* Per-tfm context shared by the cbcmac/cmac/xcbc shash algorithms. */
0127 struct mac_tfm_ctx {
0128     struct crypto_aes_ctx key;
    /*
     * Derived per-key constants: cmac_setkey() stores two be128 subkeys
     * here and xcbc_setkey() stores two AES blocks. The storage size comes
     * from the algorithm's context size, which is not visible here - verify
     * against the shash registration.
     */
0129     u8 __aligned(8) consts[];
0130 };
0131 
/* Per-request MAC state; reset by mac_init(). */
0132 struct mac_desc_ctx {
    /* length counter, zeroed by mac_init(); exact meaning is set by the
     * update/final code outside this view - TODO confirm */
0133     unsigned int len;
    /* running MAC digest, zeroed by mac_init() */
0134     u8 dg[AES_BLOCK_SIZE];
0135 };
0136 
0137 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0138                    unsigned int key_len)
0139 {
0140     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0141 
0142     return aes_expandkey(ctx, in_key, key_len);
0143 }
0144 
0145 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
0146                       const u8 *in_key, unsigned int key_len)
0147 {
0148     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0149     int ret;
0150 
0151     ret = xts_verify_key(tfm, in_key, key_len);
0152     if (ret)
0153         return ret;
0154 
0155     ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
0156     if (!ret)
0157         ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
0158                     key_len / 2);
0159     return ret;
0160 }
0161 
0162 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
0163                         const u8 *in_key,
0164                         unsigned int key_len)
0165 {
0166     struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0167     u8 digest[SHA256_DIGEST_SIZE];
0168     int ret;
0169 
0170     ret = aes_expandkey(&ctx->key1, in_key, key_len);
0171     if (ret)
0172         return ret;
0173 
0174     crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
0175 
0176     return aes_expandkey(&ctx->key2, digest, sizeof(digest));
0177 }
0178 
0179 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
0180 {
0181     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0182     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0183     int err, rounds = 6 + ctx->key_length / 4;
0184     struct skcipher_walk walk;
0185     unsigned int blocks;
0186 
0187     err = skcipher_walk_virt(&walk, req, false);
0188 
0189     while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
0190         kernel_neon_begin();
0191         aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0192                 ctx->key_enc, rounds, blocks);
0193         kernel_neon_end();
0194         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0195     }
0196     return err;
0197 }
0198 
0199 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
0200 {
0201     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0202     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0203     int err, rounds = 6 + ctx->key_length / 4;
0204     struct skcipher_walk walk;
0205     unsigned int blocks;
0206 
0207     err = skcipher_walk_virt(&walk, req, false);
0208 
0209     while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
0210         kernel_neon_begin();
0211         aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0212                 ctx->key_dec, rounds, blocks);
0213         kernel_neon_end();
0214         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0215     }
0216     return err;
0217 }
0218 
0219 static int cbc_encrypt_walk(struct skcipher_request *req,
0220                 struct skcipher_walk *walk)
0221 {
0222     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0223     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0224     int err = 0, rounds = 6 + ctx->key_length / 4;
0225     unsigned int blocks;
0226 
0227     while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
0228         kernel_neon_begin();
0229         aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
0230                 ctx->key_enc, rounds, blocks, walk->iv);
0231         kernel_neon_end();
0232         err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
0233     }
0234     return err;
0235 }
0236 
0237 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
0238 {
0239     struct skcipher_walk walk;
0240     int err;
0241 
0242     err = skcipher_walk_virt(&walk, req, false);
0243     if (err)
0244         return err;
0245     return cbc_encrypt_walk(req, &walk);
0246 }
0247 
0248 static int cbc_decrypt_walk(struct skcipher_request *req,
0249                 struct skcipher_walk *walk)
0250 {
0251     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0252     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0253     int err = 0, rounds = 6 + ctx->key_length / 4;
0254     unsigned int blocks;
0255 
0256     while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
0257         kernel_neon_begin();
0258         aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
0259                 ctx->key_dec, rounds, blocks, walk->iv);
0260         kernel_neon_end();
0261         err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
0262     }
0263     return err;
0264 }
0265 
0266 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
0267 {
0268     struct skcipher_walk walk;
0269     int err;
0270 
0271     err = skcipher_walk_virt(&walk, req, false);
0272     if (err)
0273         return err;
0274     return cbc_decrypt_walk(req, &walk);
0275 }
0276 
/*
 * cts(cbc(aes)) encryption.
 *
 * The request runs in two passes. Each pass uses a stack sub-request on
 * the same tfm, so the byte range can be narrowed without modifying @req:
 *  1. the leading cbc_blocks whole blocks are encrypted as plain CBC via
 *     cbc_encrypt_walk();
 *  2. the remaining 17..32 bytes (the last full block plus the final,
 *     possibly partial, block) go to aes_cbc_cts_encrypt() in a single
 *     call, which performs the ciphertext stealing.
 * A request of exactly AES_BLOCK_SIZE bytes is plain CBC only. Shorter
 * requests fail with -EINVAL.
 *
 * Both passes are given req->iv as their IV. Presumably the CBC pass
 * leaves its final chaining value there for pass 2 (TODO confirm against
 * the skcipher walk semantics).
 *
 * Returns 0 or a negative error code.
 */
0277 static int cts_cbc_encrypt(struct skcipher_request *req)
0278 {
0279     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0280     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0281     int err, rounds = 6 + ctx->key_length / 4;
    /*
     * Blocks before the final two (a trailing partial block counts as a
     * block). This is <= 0 for short requests; pass 1 is then skipped and
     * everything goes to the stealing pass.
     */
0282     int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
0283     struct scatterlist *src = req->src, *dst = req->dst;
    /* backing storage for the fast-forwarded scatterlists of pass 2 */
0284     struct scatterlist sg_src[2], sg_dst[2];
0285     struct skcipher_request subreq;
0286     struct skcipher_walk walk;
0287 
0288     skcipher_request_set_tfm(&subreq, tfm);
0289     skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
0290                       NULL, NULL);
0291 
    /* exactly one block: plain CBC only; less than one block: invalid */
0292     if (req->cryptlen <= AES_BLOCK_SIZE) {
0293         if (req->cryptlen < AES_BLOCK_SIZE)
0294             return -EINVAL;
0295         cbc_blocks = 1;
0296     }
0297 
    /* pass 1: plain CBC over the leading whole blocks */
0298     if (cbc_blocks > 0) {
0299         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0300                        cbc_blocks * AES_BLOCK_SIZE,
0301                        req->iv);
0302 
0303         err = skcipher_walk_virt(&walk, &subreq, false) ?:
0304               cbc_encrypt_walk(&subreq, &walk);
0305         if (err)
0306             return err;
0307 
        /* single-block request: nothing left to steal */
0308         if (req->cryptlen == AES_BLOCK_SIZE)
0309             return 0;
0310 
        /* skip src (and dst, if distinct) past the bytes pass 1 consumed */
0311         dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
0312         if (req->dst != req->src)
0313             dst = scatterwalk_ffwd(sg_dst, req->dst,
0314                            subreq.cryptlen);
0315     }
0316 
0317     /* handle ciphertext stealing */
0318     skcipher_request_set_crypt(&subreq, src, dst,
0319                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
0320                    req->iv);
0321 
0322     err = skcipher_walk_virt(&walk, &subreq, false);
0323     if (err)
0324         return err;
0325 
    /*
     * One assembler call covers the whole tail, so this relies on the walk
     * mapping all of it in one step. The cts(cbc(aes)) entry in aes_algs[]
     * sets .walksize = 2 * AES_BLOCK_SIZE, presumably for this reason.
     */
0326     kernel_neon_begin();
0327     aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0328                 ctx->key_enc, rounds, walk.nbytes, walk.iv);
0329     kernel_neon_end();
0330 
0331     return skcipher_walk_done(&walk, 0);
0332 }
0333 
/*
 * cts(cbc(aes)) decryption. This mirrors cts_cbc_encrypt() with the
 * decryption key schedule:
 *  1. the leading cbc_blocks whole blocks are decrypted as plain CBC via
 *     cbc_decrypt_walk();
 *  2. the remaining 17..32 bytes go to aes_cbc_cts_decrypt() in one call.
 * A request of exactly AES_BLOCK_SIZE bytes is plain CBC only. Shorter
 * requests fail with -EINVAL.
 *
 * Both passes use a stack sub-request on the same tfm and are given
 * req->iv as their IV. Presumably pass 1 leaves the chaining value there
 * for pass 2 (TODO confirm).
 *
 * Returns 0 or a negative error code.
 */
0334 static int cts_cbc_decrypt(struct skcipher_request *req)
0335 {
0336     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0337     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0338     int err, rounds = 6 + ctx->key_length / 4;
    /* blocks before the final two; <= 0 for short requests */
0339     int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
0340     struct scatterlist *src = req->src, *dst = req->dst;
    /* backing storage for the fast-forwarded scatterlists of pass 2 */
0341     struct scatterlist sg_src[2], sg_dst[2];
0342     struct skcipher_request subreq;
0343     struct skcipher_walk walk;
0344 
0345     skcipher_request_set_tfm(&subreq, tfm);
0346     skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
0347                       NULL, NULL);
0348 
    /* exactly one block: plain CBC only; less than one block: invalid */
0349     if (req->cryptlen <= AES_BLOCK_SIZE) {
0350         if (req->cryptlen < AES_BLOCK_SIZE)
0351             return -EINVAL;
0352         cbc_blocks = 1;
0353     }
0354 
    /* pass 1: plain CBC over the leading whole blocks */
0355     if (cbc_blocks > 0) {
0356         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0357                        cbc_blocks * AES_BLOCK_SIZE,
0358                        req->iv);
0359 
0360         err = skcipher_walk_virt(&walk, &subreq, false) ?:
0361               cbc_decrypt_walk(&subreq, &walk);
0362         if (err)
0363             return err;
0364 
        /* single-block request: nothing left to steal */
0365         if (req->cryptlen == AES_BLOCK_SIZE)
0366             return 0;
0367 
        /* skip src (and dst, if distinct) past the bytes pass 1 consumed */
0368         dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
0369         if (req->dst != req->src)
0370             dst = scatterwalk_ffwd(sg_dst, req->dst,
0371                            subreq.cryptlen);
0372     }
0373 
0374     /* handle ciphertext stealing */
0375     skcipher_request_set_crypt(&subreq, src, dst,
0376                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
0377                    req->iv);
0378 
0379     err = skcipher_walk_virt(&walk, &subreq, false);
0380     if (err)
0381         return err;
0382 
    /*
     * One assembler call covers the whole tail, so this relies on the walk
     * mapping all of it in one step (see .walksize in aes_algs[]).
     */
0383     kernel_neon_begin();
0384     aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0385                 ctx->key_dec, rounds, walk.nbytes, walk.iv);
0386     kernel_neon_end();
0387 
0388     return skcipher_walk_done(&walk, 0);
0389 }
0390 
0391 static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
0392 {
0393     struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0394 
0395     ctx->hash = crypto_alloc_shash("sha256", 0, 0);
0396 
0397     return PTR_ERR_OR_ZERO(ctx->hash);
0398 }
0399 
0400 static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
0401 {
0402     struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0403 
0404     crypto_free_shash(ctx->hash);
0405 }
0406 
0407 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
0408 {
0409     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0410     struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0411     int err, rounds = 6 + ctx->key1.key_length / 4;
0412     struct skcipher_walk walk;
0413     unsigned int blocks;
0414 
0415     err = skcipher_walk_virt(&walk, req, false);
0416 
0417     blocks = walk.nbytes / AES_BLOCK_SIZE;
0418     if (blocks) {
0419         kernel_neon_begin();
0420         aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0421                       ctx->key1.key_enc, rounds, blocks,
0422                       req->iv, ctx->key2.key_enc);
0423         kernel_neon_end();
0424         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0425     }
0426     return err ?: cbc_encrypt_walk(req, &walk);
0427 }
0428 
0429 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
0430 {
0431     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0432     struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0433     int err, rounds = 6 + ctx->key1.key_length / 4;
0434     struct skcipher_walk walk;
0435     unsigned int blocks;
0436 
0437     err = skcipher_walk_virt(&walk, req, false);
0438 
0439     blocks = walk.nbytes / AES_BLOCK_SIZE;
0440     if (blocks) {
0441         kernel_neon_begin();
0442         aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0443                       ctx->key1.key_dec, rounds, blocks,
0444                       req->iv, ctx->key2.key_enc);
0445         kernel_neon_end();
0446         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0447     }
0448     return err ?: cbc_decrypt_walk(req, &walk);
0449 }
0450 
0451 static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
0452 {
0453     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0454     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0455     int err, rounds = 6 + ctx->key_length / 4;
0456     struct skcipher_walk walk;
0457     unsigned int byte_ctr = 0;
0458 
0459     err = skcipher_walk_virt(&walk, req, false);
0460 
0461     while (walk.nbytes > 0) {
0462         const u8 *src = walk.src.virt.addr;
0463         unsigned int nbytes = walk.nbytes;
0464         u8 *dst = walk.dst.virt.addr;
0465         u8 buf[AES_BLOCK_SIZE];
0466 
0467         /*
0468          * If given less than 16 bytes, we must copy the partial block
0469          * into a temporary buffer of 16 bytes to avoid out of bounds
0470          * reads and writes.  Furthermore, this code is somewhat unusual
0471          * in that it expects the end of the data to be at the end of
0472          * the temporary buffer, rather than the start of the data at
0473          * the start of the temporary buffer.
0474          */
0475         if (unlikely(nbytes < AES_BLOCK_SIZE))
0476             src = dst = memcpy(buf + sizeof(buf) - nbytes,
0477                        src, nbytes);
0478         else if (nbytes < walk.total)
0479             nbytes &= ~(AES_BLOCK_SIZE - 1);
0480 
0481         kernel_neon_begin();
0482         aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
0483                          walk.iv, byte_ctr);
0484         kernel_neon_end();
0485 
0486         if (unlikely(nbytes < AES_BLOCK_SIZE))
0487             memcpy(walk.dst.virt.addr,
0488                    buf + sizeof(buf) - nbytes, nbytes);
0489         byte_ctr += nbytes;
0490 
0491         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0492     }
0493 
0494     return err;
0495 }
0496 
0497 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
0498 {
0499     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0500     struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
0501     int err, rounds = 6 + ctx->key_length / 4;
0502     struct skcipher_walk walk;
0503 
0504     err = skcipher_walk_virt(&walk, req, false);
0505 
0506     while (walk.nbytes > 0) {
0507         const u8 *src = walk.src.virt.addr;
0508         unsigned int nbytes = walk.nbytes;
0509         u8 *dst = walk.dst.virt.addr;
0510         u8 buf[AES_BLOCK_SIZE];
0511 
0512         /*
0513          * If given less than 16 bytes, we must copy the partial block
0514          * into a temporary buffer of 16 bytes to avoid out of bounds
0515          * reads and writes.  Furthermore, this code is somewhat unusual
0516          * in that it expects the end of the data to be at the end of
0517          * the temporary buffer, rather than the start of the data at
0518          * the start of the temporary buffer.
0519          */
0520         if (unlikely(nbytes < AES_BLOCK_SIZE))
0521             src = dst = memcpy(buf + sizeof(buf) - nbytes,
0522                        src, nbytes);
0523         else if (nbytes < walk.total)
0524             nbytes &= ~(AES_BLOCK_SIZE - 1);
0525 
0526         kernel_neon_begin();
0527         aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
0528                 walk.iv);
0529         kernel_neon_end();
0530 
0531         if (unlikely(nbytes < AES_BLOCK_SIZE))
0532             memcpy(walk.dst.virt.addr,
0533                    buf + sizeof(buf) - nbytes, nbytes);
0534 
0535         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0536     }
0537 
0538     return err;
0539 }
0540 
/*
 * xts(aes) encryption.
 *
 * Requests shorter than one block fail with -EINVAL. In the common case
 * (no partial final block, or the walk can present the whole request in
 * its last step) every step goes straight to aes_xts_encrypt(). The last
 * step is given its full byte count, partial tail included.
 *
 * When there IS a partial final block and the walk cannot present the
 * whole request at once, the first walk is aborted and the request is
 * split instead:
 *  - a stack sub-request covers the first xts_blocks whole blocks (all
 *    but the last full block and the tail);
 *  - the remaining AES_BLOCK_SIZE + tail bytes are then handled by one
 *    final assembler call, which performs the ciphertext stealing.
 *
 * @first is 1 only for the first assembler call of the request.
 * Returns 0 or a negative error code.
 */
0541 static int __maybe_unused xts_encrypt(struct skcipher_request *req)
0542 {
0543     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0544     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0545     int err, first, rounds = 6 + ctx->key1.key_length / 4;
    /* bytes in the trailing partial block (0 if block-aligned) */
0546     int tail = req->cryptlen % AES_BLOCK_SIZE;
0547     struct scatterlist sg_src[2], sg_dst[2];
0548     struct skcipher_request subreq;
0549     struct scatterlist *src, *dst;
0550     struct skcipher_walk walk;
0551 
0552     if (req->cryptlen < AES_BLOCK_SIZE)
0553         return -EINVAL;
0554 
0555     err = skcipher_walk_virt(&walk, req, false);
0556 
    /*
     * Partial tail that won't arrive in the same step as the rest: restart
     * on a sub-request limited to the leading whole blocks. From here on
     * req points at subreq.
     */
0557     if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
0558         int xts_blocks = DIV_ROUND_UP(req->cryptlen,
0559                           AES_BLOCK_SIZE) - 2;
0560 
0561         skcipher_walk_abort(&walk);
0562 
0563         skcipher_request_set_tfm(&subreq, tfm);
0564         skcipher_request_set_callback(&subreq,
0565                           skcipher_request_flags(req),
0566                           NULL, NULL);
0567         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0568                        xts_blocks * AES_BLOCK_SIZE,
0569                        req->iv);
0570         req = &subreq;
0571         err = skcipher_walk_virt(&walk, req, false);
0572     } else {
        /* any tail is handled by the main loop's last step */
0573         tail = 0;
0574     }
0575 
    /*
     * Main loop. Non-final steps are trimmed to whole blocks. If the loop
     * runs at all, first ends up 0; if it never runs (e.g. xts_blocks == 0),
     * first stays 1 for the stealing call below.
     */
0576     for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
0577         int nbytes = walk.nbytes;
0578 
0579         if (walk.nbytes < walk.total)
0580             nbytes &= ~(AES_BLOCK_SIZE - 1);
0581 
0582         kernel_neon_begin();
0583         aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0584                 ctx->key1.key_enc, rounds, nbytes,
0585                 ctx->key2.key_enc, walk.iv, first);
0586         kernel_neon_end();
0587         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0588     }
0589 
0590     if (err || likely(!tail))
0591         return err;
0592 
    /* ciphertext stealing: last full block + partial tail in one call */
0593     dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
0594     if (req->dst != req->src)
0595         dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
0596 
0597     skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
0598                    req->iv);
0599 
    /* req == &subreq on this path */
0600     err = skcipher_walk_virt(&walk, &subreq, false);
0601     if (err)
0602         return err;
0603 
    /* relies on .walksize (2 blocks) mapping the whole tail in one step */
0604     kernel_neon_begin();
0605     aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0606             ctx->key1.key_enc, rounds, walk.nbytes,
0607             ctx->key2.key_enc, walk.iv, first);
0608     kernel_neon_end();
0609 
0610     return skcipher_walk_done(&walk, 0);
0611 }
0612 
/*
 * xts(aes) decryption. This mirrors xts_encrypt(), using key1's decryption
 * schedule. key2's *encryption* schedule is still passed as rk2.
 *
 * Requests shorter than one block fail with -EINVAL. If there is a
 * partial final block that the walk cannot present together with the rest
 * of the request, the leading whole blocks are processed through a stack
 * sub-request. The last full block plus the tail are then handled by one
 * final ciphertext-stealing call.
 *
 * Returns 0 or a negative error code.
 */
0613 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
0614 {
0615     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0616     struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0617     int err, first, rounds = 6 + ctx->key1.key_length / 4;
    /* bytes in the trailing partial block (0 if block-aligned) */
0618     int tail = req->cryptlen % AES_BLOCK_SIZE;
0619     struct scatterlist sg_src[2], sg_dst[2];
0620     struct skcipher_request subreq;
0621     struct scatterlist *src, *dst;
0622     struct skcipher_walk walk;
0623 
0624     if (req->cryptlen < AES_BLOCK_SIZE)
0625         return -EINVAL;
0626 
0627     err = skcipher_walk_virt(&walk, req, false);
0628 
    /* split off the stealing tail; from here on req points at subreq */
0629     if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
0630         int xts_blocks = DIV_ROUND_UP(req->cryptlen,
0631                           AES_BLOCK_SIZE) - 2;
0632 
0633         skcipher_walk_abort(&walk);
0634 
0635         skcipher_request_set_tfm(&subreq, tfm);
0636         skcipher_request_set_callback(&subreq,
0637                           skcipher_request_flags(req),
0638                           NULL, NULL);
0639         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0640                        xts_blocks * AES_BLOCK_SIZE,
0641                        req->iv);
0642         req = &subreq;
0643         err = skcipher_walk_virt(&walk, req, false);
0644     } else {
        /* any tail is handled by the main loop's last step */
0645         tail = 0;
0646     }
0647 
    /* first stays 1 if this loop never runs; otherwise it ends up 0 */
0648     for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
0649         int nbytes = walk.nbytes;
0650 
0651         if (walk.nbytes < walk.total)
0652             nbytes &= ~(AES_BLOCK_SIZE - 1);
0653 
0654         kernel_neon_begin();
0655         aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0656                 ctx->key1.key_dec, rounds, nbytes,
0657                 ctx->key2.key_enc, walk.iv, first);
0658         kernel_neon_end();
0659         err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0660     }
0661 
0662     if (err || likely(!tail))
0663         return err;
0664 
    /* ciphertext stealing: last full block + partial tail in one call */
0665     dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
0666     if (req->dst != req->src)
0667         dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
0668 
0669     skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
0670                    req->iv);
0671 
    /* req == &subreq on this path */
0672     err = skcipher_walk_virt(&walk, &subreq, false);
0673     if (err)
0674         return err;
0675 
0676 
0677     kernel_neon_begin();
0678     aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0679             ctx->key1.key_dec, rounds, walk.nbytes,
0680             ctx->key2.key_enc, walk.iv, first);
0681     kernel_neon_end();
0682 
0683     return skcipher_walk_done(&walk, 0);
0684 }
0685 
/*
 * skcipher algorithms registered by this module. Driver names are suffixed
 * with MODE ("ce" or "neon") and use priority PRIO.
 *
 * The entries from ecb(aes) up to and including xts(aes) are compiled in
 * only under the same condition as the matching MODULE_ALIAS_CRYPTO block
 * near the top of the file. The "}, {" separator after xts(aes) is inside
 * the #if, so when the condition is false the array begins directly with
 * cts(cbc(aes)).
 */
0686 static struct skcipher_alg aes_algs[] = { {
0687 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
    /* ecb(aes): no IV */
0688     .base = {
0689         .cra_name       = "ecb(aes)",
0690         .cra_driver_name    = "ecb-aes-" MODE,
0691         .cra_priority       = PRIO,
0692         .cra_blocksize      = AES_BLOCK_SIZE,
0693         .cra_ctxsize        = sizeof(struct crypto_aes_ctx),
0694         .cra_module     = THIS_MODULE,
0695     },
0696     .min_keysize    = AES_MIN_KEY_SIZE,
0697     .max_keysize    = AES_MAX_KEY_SIZE,
0698     .setkey     = skcipher_aes_setkey,
0699     .encrypt    = ecb_encrypt,
0700     .decrypt    = ecb_decrypt,
0701 }, {
    /* cbc(aes) */
0702     .base = {
0703         .cra_name       = "cbc(aes)",
0704         .cra_driver_name    = "cbc-aes-" MODE,
0705         .cra_priority       = PRIO,
0706         .cra_blocksize      = AES_BLOCK_SIZE,
0707         .cra_ctxsize        = sizeof(struct crypto_aes_ctx),
0708         .cra_module     = THIS_MODULE,
0709     },
0710     .min_keysize    = AES_MIN_KEY_SIZE,
0711     .max_keysize    = AES_MAX_KEY_SIZE,
0712     .ivsize     = AES_BLOCK_SIZE,
0713     .setkey     = skcipher_aes_setkey,
0714     .encrypt    = cbc_encrypt,
0715     .decrypt    = cbc_decrypt,
0716 }, {
    /*
     * ctr(aes): stream mode, so cra_blocksize is 1 and one routine serves
     * both directions; chunksize is the AES block size.
     */
0717     .base = {
0718         .cra_name       = "ctr(aes)",
0719         .cra_driver_name    = "ctr-aes-" MODE,
0720         .cra_priority       = PRIO,
0721         .cra_blocksize      = 1,
0722         .cra_ctxsize        = sizeof(struct crypto_aes_ctx),
0723         .cra_module     = THIS_MODULE,
0724     },
0725     .min_keysize    = AES_MIN_KEY_SIZE,
0726     .max_keysize    = AES_MAX_KEY_SIZE,
0727     .ivsize     = AES_BLOCK_SIZE,
0728     .chunksize  = AES_BLOCK_SIZE,
0729     .setkey     = skcipher_aes_setkey,
0730     .encrypt    = ctr_encrypt,
0731     .decrypt    = ctr_encrypt,
0732 }, {
    /* xctr(aes): stream mode, same layout as ctr(aes) */
0733     .base = {
0734         .cra_name       = "xctr(aes)",
0735         .cra_driver_name    = "xctr-aes-" MODE,
0736         .cra_priority       = PRIO,
0737         .cra_blocksize      = 1,
0738         .cra_ctxsize        = sizeof(struct crypto_aes_ctx),
0739         .cra_module     = THIS_MODULE,
0740     },
0741     .min_keysize    = AES_MIN_KEY_SIZE,
0742     .max_keysize    = AES_MAX_KEY_SIZE,
0743     .ivsize     = AES_BLOCK_SIZE,
0744     .chunksize  = AES_BLOCK_SIZE,
0745     .setkey     = skcipher_aes_setkey,
0746     .encrypt    = xctr_encrypt,
0747     .decrypt    = xctr_encrypt,
0748 }, {
    /*
     * xts(aes): the key is two AES keys back to back, hence the doubled
     * key sizes. walksize of two blocks - presumably so the ciphertext-
     * stealing tail in xts_encrypt/xts_decrypt is mapped in one step.
     */
0749     .base = {
0750         .cra_name       = "xts(aes)",
0751         .cra_driver_name    = "xts-aes-" MODE,
0752         .cra_priority       = PRIO,
0753         .cra_blocksize      = AES_BLOCK_SIZE,
0754         .cra_ctxsize        = sizeof(struct crypto_aes_xts_ctx),
0755         .cra_module     = THIS_MODULE,
0756     },
0757     .min_keysize    = 2 * AES_MIN_KEY_SIZE,
0758     .max_keysize    = 2 * AES_MAX_KEY_SIZE,
0759     .ivsize     = AES_BLOCK_SIZE,
0760     .walksize   = 2 * AES_BLOCK_SIZE,
0761     .setkey     = xts_set_key,
0762     .encrypt    = xts_encrypt,
0763     .decrypt    = xts_decrypt,
0764 }, {
0765 #endif
    /*
     * cts(cbc(aes)): always registered. walksize of two blocks -
     * presumably so the final stealing tail is mapped in one walk step.
     */
0766     .base = {
0767         .cra_name       = "cts(cbc(aes))",
0768         .cra_driver_name    = "cts-cbc-aes-" MODE,
0769         .cra_priority       = PRIO,
0770         .cra_blocksize      = AES_BLOCK_SIZE,
0771         .cra_ctxsize        = sizeof(struct crypto_aes_ctx),
0772         .cra_module     = THIS_MODULE,
0773     },
0774     .min_keysize    = AES_MIN_KEY_SIZE,
0775     .max_keysize    = AES_MAX_KEY_SIZE,
0776     .ivsize     = AES_BLOCK_SIZE,
0777     .walksize   = 2 * AES_BLOCK_SIZE,
0778     .setkey     = skcipher_aes_setkey,
0779     .encrypt    = cts_cbc_encrypt,
0780     .decrypt    = cts_cbc_decrypt,
0781 }, {
    /*
     * essiv(cbc(aes),sha256): always registered, at PRIO + 1. init/exit
     * allocate and free the SHA-256 shash used to derive key2.
     */
0782     .base = {
0783         .cra_name       = "essiv(cbc(aes),sha256)",
0784         .cra_driver_name    = "essiv-cbc-aes-sha256-" MODE,
0785         .cra_priority       = PRIO + 1,
0786         .cra_blocksize      = AES_BLOCK_SIZE,
0787         .cra_ctxsize        = sizeof(struct crypto_aes_essiv_cbc_ctx),
0788         .cra_module     = THIS_MODULE,
0789     },
0790     .min_keysize    = AES_MIN_KEY_SIZE,
0791     .max_keysize    = AES_MAX_KEY_SIZE,
0792     .ivsize     = AES_BLOCK_SIZE,
0793     .setkey     = essiv_cbc_set_key,
0794     .encrypt    = essiv_cbc_encrypt,
0795     .decrypt    = essiv_cbc_decrypt,
0796     .init       = essiv_cbc_init_tfm,
0797     .exit       = essiv_cbc_exit_tfm,
0798 } };
0799 
0800 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
0801              unsigned int key_len)
0802 {
0803     struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
0804 
0805     return aes_expandkey(&ctx->key, in_key, key_len);
0806 }
0807 
0808 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
0809 {
0810     u64 a = be64_to_cpu(x->a);
0811     u64 b = be64_to_cpu(x->b);
0812 
0813     y->a = cpu_to_be64((a << 1) | (b >> 63));
0814     y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
0815 }
0816 
0817 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
0818                unsigned int key_len)
0819 {
0820     struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
0821     be128 *consts = (be128 *)ctx->consts;
0822     int rounds = 6 + key_len / 4;
0823     int err;
0824 
0825     err = cbcmac_setkey(tfm, in_key, key_len);
0826     if (err)
0827         return err;
0828 
0829     /* encrypt the zero vector */
0830     kernel_neon_begin();
0831     aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
0832             rounds, 1);
0833     kernel_neon_end();
0834 
0835     cmac_gf128_mul_by_x(consts, consts);
0836     cmac_gf128_mul_by_x(consts + 1, consts);
0837 
0838     return 0;
0839 }
0840 
0841 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
0842                unsigned int key_len)
0843 {
0844     static u8 const ks[3][AES_BLOCK_SIZE] = {
0845         { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
0846         { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
0847         { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
0848     };
0849 
0850     struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
0851     int rounds = 6 + key_len / 4;
0852     u8 key[AES_BLOCK_SIZE];
0853     int err;
0854 
0855     err = cbcmac_setkey(tfm, in_key, key_len);
0856     if (err)
0857         return err;
0858 
0859     kernel_neon_begin();
0860     aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
0861     aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
0862     kernel_neon_end();
0863 
0864     return cbcmac_setkey(tfm, key, sizeof(key));
0865 }
0866 
0867 static int mac_init(struct shash_desc *desc)
0868 {
0869     struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
0870 
0871     memset(ctx->dg, 0, AES_BLOCK_SIZE);
0872     ctx->len = 0;
0873 
0874     return 0;
0875 }
0876 
0877 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
0878               u8 dg[], int enc_before, int enc_after)
0879 {
0880     int rounds = 6 + ctx->key_length / 4;
0881 
0882     if (crypto_simd_usable()) {
0883         int rem;
0884 
0885         do {
0886             kernel_neon_begin();
0887             rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
0888                          dg, enc_before, enc_after);
0889             kernel_neon_end();
0890             in += (blocks - rem) * AES_BLOCK_SIZE;
0891             blocks = rem;
0892             enc_before = 0;
0893         } while (blocks);
0894     } else {
0895         if (enc_before)
0896             aes_encrypt(ctx, dg, dg);
0897 
0898         while (blocks--) {
0899             crypto_xor(dg, in, AES_BLOCK_SIZE);
0900             in += AES_BLOCK_SIZE;
0901 
0902             if (blocks || enc_after)
0903                 aes_encrypt(ctx, dg, dg);
0904         }
0905     }
0906 }
0907 
0908 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
0909 {
0910     struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
0911     struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
0912 
0913     while (len > 0) {
0914         unsigned int l;
0915 
0916         if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
0917             (ctx->len + len) > AES_BLOCK_SIZE) {
0918 
0919             int blocks = len / AES_BLOCK_SIZE;
0920 
0921             len %= AES_BLOCK_SIZE;
0922 
0923             mac_do_update(&tctx->key, p, blocks, ctx->dg,
0924                       (ctx->len != 0), (len != 0));
0925 
0926             p += blocks * AES_BLOCK_SIZE;
0927 
0928             if (!len) {
0929                 ctx->len = AES_BLOCK_SIZE;
0930                 break;
0931             }
0932             ctx->len = 0;
0933         }
0934 
0935         l = min(len, AES_BLOCK_SIZE - ctx->len);
0936 
0937         if (l <= AES_BLOCK_SIZE) {
0938             crypto_xor(ctx->dg + ctx->len, p, l);
0939             ctx->len += l;
0940             len -= l;
0941             p += l;
0942         }
0943     }
0944 
0945     return 0;
0946 }
0947 
0948 static int cbcmac_final(struct shash_desc *desc, u8 *out)
0949 {
0950     struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
0951     struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
0952 
0953     mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
0954 
0955     memcpy(out, ctx->dg, AES_BLOCK_SIZE);
0956 
0957     return 0;
0958 }
0959 
0960 static int cmac_final(struct shash_desc *desc, u8 *out)
0961 {
0962     struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
0963     struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
0964     u8 *consts = tctx->consts;
0965 
0966     if (ctx->len != AES_BLOCK_SIZE) {
0967         ctx->dg[ctx->len] ^= 0x80;
0968         consts += AES_BLOCK_SIZE;
0969     }
0970 
0971     mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
0972 
0973     memcpy(out, ctx->dg, AES_BLOCK_SIZE);
0974 
0975     return 0;
0976 }
0977 
0978 static struct shash_alg mac_algs[] = { {
0979     .base.cra_name      = "cmac(aes)",
0980     .base.cra_driver_name   = "cmac-aes-" MODE,
0981     .base.cra_priority  = PRIO,
0982     .base.cra_blocksize = AES_BLOCK_SIZE,
0983     .base.cra_ctxsize   = sizeof(struct mac_tfm_ctx) +
0984                   2 * AES_BLOCK_SIZE,
0985     .base.cra_module    = THIS_MODULE,
0986 
0987     .digestsize     = AES_BLOCK_SIZE,
0988     .init           = mac_init,
0989     .update         = mac_update,
0990     .final          = cmac_final,
0991     .setkey         = cmac_setkey,
0992     .descsize       = sizeof(struct mac_desc_ctx),
0993 }, {
0994     .base.cra_name      = "xcbc(aes)",
0995     .base.cra_driver_name   = "xcbc-aes-" MODE,
0996     .base.cra_priority  = PRIO,
0997     .base.cra_blocksize = AES_BLOCK_SIZE,
0998     .base.cra_ctxsize   = sizeof(struct mac_tfm_ctx) +
0999                   2 * AES_BLOCK_SIZE,
1000     .base.cra_module    = THIS_MODULE,
1001 
1002     .digestsize     = AES_BLOCK_SIZE,
1003     .init           = mac_init,
1004     .update         = mac_update,
1005     .final          = cmac_final,
1006     .setkey         = xcbc_setkey,
1007     .descsize       = sizeof(struct mac_desc_ctx),
1008 }, {
1009     .base.cra_name      = "cbcmac(aes)",
1010     .base.cra_driver_name   = "cbcmac-aes-" MODE,
1011     .base.cra_priority  = PRIO,
1012     .base.cra_blocksize = 1,
1013     .base.cra_ctxsize   = sizeof(struct mac_tfm_ctx),
1014     .base.cra_module    = THIS_MODULE,
1015 
1016     .digestsize     = AES_BLOCK_SIZE,
1017     .init           = mac_init,
1018     .update         = mac_update,
1019     .final          = cbcmac_final,
1020     .setkey         = cbcmac_setkey,
1021     .descsize       = sizeof(struct mac_desc_ctx),
1022 } };
1023 
1024 static void aes_exit(void)
1025 {
1026     crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1027     crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1028 }
1029 
1030 static int __init aes_init(void)
1031 {
1032     int err;
1033 
1034     err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1035     if (err)
1036         return err;
1037 
1038     err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1039     if (err)
1040         goto unregister_ciphers;
1041 
1042     return 0;
1043 
1044 unregister_ciphers:
1045     crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1046     return err;
1047 }
1048 
#ifdef USE_V8_CRYPTO_EXTENSIONS
/*
 * Crypto Extensions build: aes_init() is tied to the AES CPU feature via
 * module_cpu_feature_match().
 */
module_cpu_feature_match(AES, aes_init);
#else
/* Plain NEON build: aes_init() is the ordinary module init hook. */
module_init(aes_init);
/*
 * Export a subset of the NEON AES routines for use outside this module.
 * NOTE(review): presumably consumed by another arm64 AES driver (e.g. the
 * bit-sliced NEON one) — confirm against the callers before changing.
 */
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
/* Both builds share the same teardown. */
module_exit(aes_exit);