Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * Bit sliced AES using NEON instructions
0004  *
0005  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
0006  */
0007 
0008 #include <asm/neon.h>
0009 #include <asm/simd.h>
0010 #include <crypto/aes.h>
0011 #include <crypto/ctr.h>
0012 #include <crypto/internal/cipher.h>
0013 #include <crypto/internal/simd.h>
0014 #include <crypto/internal/skcipher.h>
0015 #include <crypto/scatterwalk.h>
0016 #include <crypto/xts.h>
0017 #include <linux/module.h>
0018 
0019 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
0020 MODULE_LICENSE("GPL v2");
0021 
0022 MODULE_ALIAS_CRYPTO("ecb(aes)");
0023 MODULE_ALIAS_CRYPTO("cbc(aes)-all");
0024 MODULE_ALIAS_CRYPTO("ctr(aes)");
0025 MODULE_ALIAS_CRYPTO("xts(aes)");
0026 
0027 MODULE_IMPORT_NS(CRYPTO_INTERNAL);
0028 
0029 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
0030 
0031 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
0032                   int rounds, int blocks);
0033 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
0034                   int rounds, int blocks);
0035 
0036 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
0037                   int rounds, int blocks, u8 iv[]);
0038 
0039 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
0040                   int rounds, int blocks, u8 ctr[]);
0041 
0042 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
0043                   int rounds, int blocks, u8 iv[], int);
0044 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
0045                   int rounds, int blocks, u8 iv[], int);
0046 
0047 struct aesbs_ctx {
0048     int rounds;
0049     u8  rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
0050 };
0051 
0052 struct aesbs_cbc_ctx {
0053     struct aesbs_ctx    key;
0054     struct crypto_skcipher  *enc_tfm;
0055 };
0056 
0057 struct aesbs_xts_ctx {
0058     struct aesbs_ctx    key;
0059     struct crypto_cipher    *cts_tfm;
0060     struct crypto_cipher    *tweak_tfm;
0061 };
0062 
0063 struct aesbs_ctr_ctx {
0064     struct aesbs_ctx    key;        /* must be first member */
0065     struct crypto_aes_ctx   fallback;
0066 };
0067 
0068 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0069             unsigned int key_len)
0070 {
0071     struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
0072     struct crypto_aes_ctx rk;
0073     int err;
0074 
0075     err = aes_expandkey(&rk, in_key, key_len);
0076     if (err)
0077         return err;
0078 
0079     ctx->rounds = 6 + key_len / 4;
0080 
0081     kernel_neon_begin();
0082     aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
0083     kernel_neon_end();
0084 
0085     return 0;
0086 }
0087 
0088 static int __ecb_crypt(struct skcipher_request *req,
0089                void (*fn)(u8 out[], u8 const in[], u8 const rk[],
0090                   int rounds, int blocks))
0091 {
0092     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0093     struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
0094     struct skcipher_walk walk;
0095     int err;
0096 
0097     err = skcipher_walk_virt(&walk, req, false);
0098 
0099     while (walk.nbytes >= AES_BLOCK_SIZE) {
0100         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0101 
0102         if (walk.nbytes < walk.total)
0103             blocks = round_down(blocks,
0104                         walk.stride / AES_BLOCK_SIZE);
0105 
0106         kernel_neon_begin();
0107         fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
0108            ctx->rounds, blocks);
0109         kernel_neon_end();
0110         err = skcipher_walk_done(&walk,
0111                      walk.nbytes - blocks * AES_BLOCK_SIZE);
0112     }
0113 
0114     return err;
0115 }
0116 
0117 static int ecb_encrypt(struct skcipher_request *req)
0118 {
0119     return __ecb_crypt(req, aesbs_ecb_encrypt);
0120 }
0121 
0122 static int ecb_decrypt(struct skcipher_request *req)
0123 {
0124     return __ecb_crypt(req, aesbs_ecb_decrypt);
0125 }
0126 
0127 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0128                 unsigned int key_len)
0129 {
0130     struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0131     struct crypto_aes_ctx rk;
0132     int err;
0133 
0134     err = aes_expandkey(&rk, in_key, key_len);
0135     if (err)
0136         return err;
0137 
0138     ctx->key.rounds = 6 + key_len / 4;
0139 
0140     kernel_neon_begin();
0141     aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
0142     kernel_neon_end();
0143     memzero_explicit(&rk, sizeof(rk));
0144 
0145     return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
0146 }
0147 
0148 static int cbc_encrypt(struct skcipher_request *req)
0149 {
0150     struct skcipher_request *subreq = skcipher_request_ctx(req);
0151     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0152     struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0153 
0154     skcipher_request_set_tfm(subreq, ctx->enc_tfm);
0155     skcipher_request_set_callback(subreq,
0156                       skcipher_request_flags(req),
0157                       NULL, NULL);
0158     skcipher_request_set_crypt(subreq, req->src, req->dst,
0159                    req->cryptlen, req->iv);
0160 
0161     return crypto_skcipher_encrypt(subreq);
0162 }
0163 
0164 static int cbc_decrypt(struct skcipher_request *req)
0165 {
0166     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0167     struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0168     struct skcipher_walk walk;
0169     int err;
0170 
0171     err = skcipher_walk_virt(&walk, req, false);
0172 
0173     while (walk.nbytes >= AES_BLOCK_SIZE) {
0174         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0175 
0176         if (walk.nbytes < walk.total)
0177             blocks = round_down(blocks,
0178                         walk.stride / AES_BLOCK_SIZE);
0179 
0180         kernel_neon_begin();
0181         aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0182                   ctx->key.rk, ctx->key.rounds, blocks,
0183                   walk.iv);
0184         kernel_neon_end();
0185         err = skcipher_walk_done(&walk,
0186                      walk.nbytes - blocks * AES_BLOCK_SIZE);
0187     }
0188 
0189     return err;
0190 }
0191 
0192 static int cbc_init(struct crypto_skcipher *tfm)
0193 {
0194     struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0195     unsigned int reqsize;
0196 
0197     ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
0198                          CRYPTO_ALG_NEED_FALLBACK);
0199     if (IS_ERR(ctx->enc_tfm))
0200         return PTR_ERR(ctx->enc_tfm);
0201 
0202     reqsize = sizeof(struct skcipher_request);
0203     reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
0204     crypto_skcipher_set_reqsize(tfm, reqsize);
0205 
0206     return 0;
0207 }
0208 
0209 static void cbc_exit(struct crypto_skcipher *tfm)
0210 {
0211     struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
0212 
0213     crypto_free_skcipher(ctx->enc_tfm);
0214 }
0215 
0216 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
0217                  unsigned int key_len)
0218 {
0219     struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0220     int err;
0221 
0222     err = aes_expandkey(&ctx->fallback, in_key, key_len);
0223     if (err)
0224         return err;
0225 
0226     ctx->key.rounds = 6 + key_len / 4;
0227 
0228     kernel_neon_begin();
0229     aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
0230     kernel_neon_end();
0231 
0232     return 0;
0233 }
0234 
0235 static int ctr_encrypt(struct skcipher_request *req)
0236 {
0237     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0238     struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
0239     struct skcipher_walk walk;
0240     u8 buf[AES_BLOCK_SIZE];
0241     int err;
0242 
0243     err = skcipher_walk_virt(&walk, req, false);
0244 
0245     while (walk.nbytes > 0) {
0246         const u8 *src = walk.src.virt.addr;
0247         u8 *dst = walk.dst.virt.addr;
0248         int bytes = walk.nbytes;
0249 
0250         if (unlikely(bytes < AES_BLOCK_SIZE))
0251             src = dst = memcpy(buf + sizeof(buf) - bytes,
0252                        src, bytes);
0253         else if (walk.nbytes < walk.total)
0254             bytes &= ~(8 * AES_BLOCK_SIZE - 1);
0255 
0256         kernel_neon_begin();
0257         aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
0258         kernel_neon_end();
0259 
0260         if (unlikely(bytes < AES_BLOCK_SIZE))
0261             memcpy(walk.dst.virt.addr,
0262                    buf + sizeof(buf) - bytes, bytes);
0263 
0264         err = skcipher_walk_done(&walk, walk.nbytes - bytes);
0265     }
0266 
0267     return err;
0268 }
0269 
0270 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
0271 {
0272     struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0273     unsigned long flags;
0274 
0275     /*
0276      * Temporarily disable interrupts to avoid races where
0277      * cachelines are evicted when the CPU is interrupted
0278      * to do something else.
0279      */
0280     local_irq_save(flags);
0281     aes_encrypt(&ctx->fallback, dst, src);
0282     local_irq_restore(flags);
0283 }
0284 
0285 static int ctr_encrypt_sync(struct skcipher_request *req)
0286 {
0287     if (!crypto_simd_usable())
0288         return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
0289 
0290     return ctr_encrypt(req);
0291 }
0292 
0293 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0294                 unsigned int key_len)
0295 {
0296     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0297     int err;
0298 
0299     err = xts_verify_key(tfm, in_key, key_len);
0300     if (err)
0301         return err;
0302 
0303     key_len /= 2;
0304     err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
0305     if (err)
0306         return err;
0307     err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
0308     if (err)
0309         return err;
0310 
0311     return aesbs_setkey(tfm, in_key, key_len);
0312 }
0313 
0314 static int xts_init(struct crypto_skcipher *tfm)
0315 {
0316     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0317 
0318     ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
0319     if (IS_ERR(ctx->cts_tfm))
0320         return PTR_ERR(ctx->cts_tfm);
0321 
0322     ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
0323     if (IS_ERR(ctx->tweak_tfm))
0324         crypto_free_cipher(ctx->cts_tfm);
0325 
0326     return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
0327 }
0328 
0329 static void xts_exit(struct crypto_skcipher *tfm)
0330 {
0331     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0332 
0333     crypto_free_cipher(ctx->tweak_tfm);
0334     crypto_free_cipher(ctx->cts_tfm);
0335 }
0336 
0337 static int __xts_crypt(struct skcipher_request *req, bool encrypt,
0338                void (*fn)(u8 out[], u8 const in[], u8 const rk[],
0339                   int rounds, int blocks, u8 iv[], int))
0340 {
0341     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0342     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0343     int tail = req->cryptlen % AES_BLOCK_SIZE;
0344     struct skcipher_request subreq;
0345     u8 buf[2 * AES_BLOCK_SIZE];
0346     struct skcipher_walk walk;
0347     int err;
0348 
0349     if (req->cryptlen < AES_BLOCK_SIZE)
0350         return -EINVAL;
0351 
0352     if (unlikely(tail)) {
0353         skcipher_request_set_tfm(&subreq, tfm);
0354         skcipher_request_set_callback(&subreq,
0355                           skcipher_request_flags(req),
0356                           NULL, NULL);
0357         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0358                        req->cryptlen - tail, req->iv);
0359         req = &subreq;
0360     }
0361 
0362     err = skcipher_walk_virt(&walk, req, true);
0363     if (err)
0364         return err;
0365 
0366     crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);
0367 
0368     while (walk.nbytes >= AES_BLOCK_SIZE) {
0369         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0370         int reorder_last_tweak = !encrypt && tail > 0;
0371 
0372         if (walk.nbytes < walk.total) {
0373             blocks = round_down(blocks,
0374                         walk.stride / AES_BLOCK_SIZE);
0375             reorder_last_tweak = 0;
0376         }
0377 
0378         kernel_neon_begin();
0379         fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
0380            ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
0381         kernel_neon_end();
0382         err = skcipher_walk_done(&walk,
0383                      walk.nbytes - blocks * AES_BLOCK_SIZE);
0384     }
0385 
0386     if (err || likely(!tail))
0387         return err;
0388 
0389     /* handle ciphertext stealing */
0390     scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
0391                  AES_BLOCK_SIZE, 0);
0392     memcpy(buf + AES_BLOCK_SIZE, buf, tail);
0393     scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
0394 
0395     crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
0396 
0397     if (encrypt)
0398         crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
0399     else
0400         crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
0401 
0402     crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
0403 
0404     scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
0405                  AES_BLOCK_SIZE + tail, 1);
0406     return 0;
0407 }
0408 
0409 static int xts_encrypt(struct skcipher_request *req)
0410 {
0411     return __xts_crypt(req, true, aesbs_xts_encrypt);
0412 }
0413 
0414 static int xts_decrypt(struct skcipher_request *req)
0415 {
0416     return __xts_crypt(req, false, aesbs_xts_decrypt);
0417 }
0418 
0419 static struct skcipher_alg aes_algs[] = { {
0420     .base.cra_name      = "__ecb(aes)",
0421     .base.cra_driver_name   = "__ecb-aes-neonbs",
0422     .base.cra_priority  = 250,
0423     .base.cra_blocksize = AES_BLOCK_SIZE,
0424     .base.cra_ctxsize   = sizeof(struct aesbs_ctx),
0425     .base.cra_module    = THIS_MODULE,
0426     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0427 
0428     .min_keysize        = AES_MIN_KEY_SIZE,
0429     .max_keysize        = AES_MAX_KEY_SIZE,
0430     .walksize       = 8 * AES_BLOCK_SIZE,
0431     .setkey         = aesbs_setkey,
0432     .encrypt        = ecb_encrypt,
0433     .decrypt        = ecb_decrypt,
0434 }, {
0435     .base.cra_name      = "__cbc(aes)",
0436     .base.cra_driver_name   = "__cbc-aes-neonbs",
0437     .base.cra_priority  = 250,
0438     .base.cra_blocksize = AES_BLOCK_SIZE,
0439     .base.cra_ctxsize   = sizeof(struct aesbs_cbc_ctx),
0440     .base.cra_module    = THIS_MODULE,
0441     .base.cra_flags     = CRYPTO_ALG_INTERNAL |
0442                   CRYPTO_ALG_NEED_FALLBACK,
0443 
0444     .min_keysize        = AES_MIN_KEY_SIZE,
0445     .max_keysize        = AES_MAX_KEY_SIZE,
0446     .walksize       = 8 * AES_BLOCK_SIZE,
0447     .ivsize         = AES_BLOCK_SIZE,
0448     .setkey         = aesbs_cbc_setkey,
0449     .encrypt        = cbc_encrypt,
0450     .decrypt        = cbc_decrypt,
0451     .init           = cbc_init,
0452     .exit           = cbc_exit,
0453 }, {
0454     .base.cra_name      = "__ctr(aes)",
0455     .base.cra_driver_name   = "__ctr-aes-neonbs",
0456     .base.cra_priority  = 250,
0457     .base.cra_blocksize = 1,
0458     .base.cra_ctxsize   = sizeof(struct aesbs_ctx),
0459     .base.cra_module    = THIS_MODULE,
0460     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0461 
0462     .min_keysize        = AES_MIN_KEY_SIZE,
0463     .max_keysize        = AES_MAX_KEY_SIZE,
0464     .chunksize      = AES_BLOCK_SIZE,
0465     .walksize       = 8 * AES_BLOCK_SIZE,
0466     .ivsize         = AES_BLOCK_SIZE,
0467     .setkey         = aesbs_setkey,
0468     .encrypt        = ctr_encrypt,
0469     .decrypt        = ctr_encrypt,
0470 }, {
0471     .base.cra_name      = "ctr(aes)",
0472     .base.cra_driver_name   = "ctr-aes-neonbs-sync",
0473     .base.cra_priority  = 250 - 1,
0474     .base.cra_blocksize = 1,
0475     .base.cra_ctxsize   = sizeof(struct aesbs_ctr_ctx),
0476     .base.cra_module    = THIS_MODULE,
0477 
0478     .min_keysize        = AES_MIN_KEY_SIZE,
0479     .max_keysize        = AES_MAX_KEY_SIZE,
0480     .chunksize      = AES_BLOCK_SIZE,
0481     .walksize       = 8 * AES_BLOCK_SIZE,
0482     .ivsize         = AES_BLOCK_SIZE,
0483     .setkey         = aesbs_ctr_setkey_sync,
0484     .encrypt        = ctr_encrypt_sync,
0485     .decrypt        = ctr_encrypt_sync,
0486 }, {
0487     .base.cra_name      = "__xts(aes)",
0488     .base.cra_driver_name   = "__xts-aes-neonbs",
0489     .base.cra_priority  = 250,
0490     .base.cra_blocksize = AES_BLOCK_SIZE,
0491     .base.cra_ctxsize   = sizeof(struct aesbs_xts_ctx),
0492     .base.cra_module    = THIS_MODULE,
0493     .base.cra_flags     = CRYPTO_ALG_INTERNAL,
0494 
0495     .min_keysize        = 2 * AES_MIN_KEY_SIZE,
0496     .max_keysize        = 2 * AES_MAX_KEY_SIZE,
0497     .walksize       = 8 * AES_BLOCK_SIZE,
0498     .ivsize         = AES_BLOCK_SIZE,
0499     .setkey         = aesbs_xts_setkey,
0500     .encrypt        = xts_encrypt,
0501     .decrypt        = xts_decrypt,
0502     .init           = xts_init,
0503     .exit           = xts_exit,
0504 } };
0505 
0506 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
0507 
0508 static void aes_exit(void)
0509 {
0510     int i;
0511 
0512     for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
0513         if (aes_simd_algs[i])
0514             simd_skcipher_free(aes_simd_algs[i]);
0515 
0516     crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0517 }
0518 
0519 static int __init aes_init(void)
0520 {
0521     struct simd_skcipher_alg *simd;
0522     const char *basename;
0523     const char *algname;
0524     const char *drvname;
0525     int err;
0526     int i;
0527 
0528     if (!(elf_hwcap & HWCAP_NEON))
0529         return -ENODEV;
0530 
0531     err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0532     if (err)
0533         return err;
0534 
0535     for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
0536         if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
0537             continue;
0538 
0539         algname = aes_algs[i].base.cra_name + 2;
0540         drvname = aes_algs[i].base.cra_driver_name + 2;
0541         basename = aes_algs[i].base.cra_driver_name;
0542         simd = simd_skcipher_create_compat(algname, drvname, basename);
0543         err = PTR_ERR(simd);
0544         if (IS_ERR(simd))
0545             goto unregister_simds;
0546 
0547         aes_simd_algs[i] = simd;
0548     }
0549     return 0;
0550 
0551 unregister_simds:
0552     aes_exit();
0553     return err;
0554 }
0555 
/* aes_init() runs as a late initcall; aes_exit() runs on module removal. */
late_initcall(aes_init);
module_exit(aes_exit);