Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * Bit sliced AES using NEON instructions
0004  *
0005  * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
0006  */
0007 
0008 #include <asm/neon.h>
0009 #include <asm/simd.h>
0010 #include <crypto/aes.h>
0011 #include <crypto/ctr.h>
0012 #include <crypto/internal/simd.h>
0013 #include <crypto/internal/skcipher.h>
0014 #include <crypto/scatterwalk.h>
0015 #include <crypto/xts.h>
0016 #include <linux/module.h>
0017 
0018 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
0019 MODULE_LICENSE("GPL v2");
0020 
0021 MODULE_ALIAS_CRYPTO("ecb(aes)");
0022 MODULE_ALIAS_CRYPTO("cbc(aes)");
0023 MODULE_ALIAS_CRYPTO("ctr(aes)");
0024 MODULE_ALIAS_CRYPTO("xts(aes)");
0025 
0026 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
0027 
0028 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
0029                   int rounds, int blocks);
0030 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
0031                   int rounds, int blocks);
0032 
0033 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
0034                   int rounds, int blocks, u8 iv[]);
0035 
0036 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
0037                   int rounds, int blocks, u8 iv[]);
0038 
0039 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
0040                   int rounds, int blocks, u8 iv[]);
0041 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
0042                   int rounds, int blocks, u8 iv[]);
0043 
0044 /* borrowed from aes-neon-blk.ko */
0045 asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
0046                      int rounds, int blocks);
0047 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
0048                      int rounds, int blocks, u8 iv[]);
0049 asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
0050                      int rounds, int bytes, u8 ctr[]);
0051 asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
0052                      u32 const rk1[], int rounds, int bytes,
0053                      u32 const rk2[], u8 iv[], int first);
0054 asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
0055                      u32 const rk1[], int rounds, int bytes,
0056                      u32 const rk2[], u8 iv[], int first);
0057 
0058 struct aesbs_ctx {
0059     u8  rk[13 * (8 * AES_BLOCK_SIZE) + 32];
0060     int rounds;
0061 } __aligned(AES_BLOCK_SIZE);
0062 
0063 struct aesbs_cbc_ctr_ctx {
0064     struct aesbs_ctx    key;
0065     u32         enc[AES_MAX_KEYLENGTH_U32];
0066 };
0067 
0068 struct aesbs_xts_ctx {
0069     struct aesbs_ctx    key;
0070     u32         twkey[AES_MAX_KEYLENGTH_U32];
0071     struct crypto_aes_ctx   cts;
0072 };
0073 
0074 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0075             unsigned int key_len)
0076 {
0077     struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
0078     struct crypto_aes_ctx rk;
0079     int err;
0080 
0081     err = aes_expandkey(&rk, in_key, key_len);
0082     if (err)
0083         return err;
0084 
0085     ctx->rounds = 6 + key_len / 4;
0086 
0087     kernel_neon_begin();
0088     aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
0089     kernel_neon_end();
0090 
0091     return 0;
0092 }
0093 
0094 static int __ecb_crypt(struct skcipher_request *req,
0095                void (*fn)(u8 out[], u8 const in[], u8 const rk[],
0096                   int rounds, int blocks))
0097 {
0098     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0099     struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
0100     struct skcipher_walk walk;
0101     int err;
0102 
0103     err = skcipher_walk_virt(&walk, req, false);
0104 
0105     while (walk.nbytes >= AES_BLOCK_SIZE) {
0106         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0107 
0108         if (walk.nbytes < walk.total)
0109             blocks = round_down(blocks,
0110                         walk.stride / AES_BLOCK_SIZE);
0111 
0112         kernel_neon_begin();
0113         fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
0114            ctx->rounds, blocks);
0115         kernel_neon_end();
0116         err = skcipher_walk_done(&walk,
0117                      walk.nbytes - blocks * AES_BLOCK_SIZE);
0118     }
0119 
0120     return err;
0121 }
0122 
0123 static int ecb_encrypt(struct skcipher_request *req)
0124 {
0125     return __ecb_crypt(req, aesbs_ecb_encrypt);
0126 }
0127 
0128 static int ecb_decrypt(struct skcipher_request *req)
0129 {
0130     return __ecb_crypt(req, aesbs_ecb_decrypt);
0131 }
0132 
0133 static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0134                 unsigned int key_len)
0135 {
0136     struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0137     struct crypto_aes_ctx rk;
0138     int err;
0139 
0140     err = aes_expandkey(&rk, in_key, key_len);
0141     if (err)
0142         return err;
0143 
0144     ctx->key.rounds = 6 + key_len / 4;
0145 
0146     memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
0147 
0148     kernel_neon_begin();
0149     aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
0150     kernel_neon_end();
0151     memzero_explicit(&rk, sizeof(rk));
0152 
0153     return 0;
0154 }
0155 
0156 static int cbc_encrypt(struct skcipher_request *req)
0157 {
0158     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0159     struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0160     struct skcipher_walk walk;
0161     int err;
0162 
0163     err = skcipher_walk_virt(&walk, req, false);
0164 
0165     while (walk.nbytes >= AES_BLOCK_SIZE) {
0166         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0167 
0168         /* fall back to the non-bitsliced NEON implementation */
0169         kernel_neon_begin();
0170         neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
0171                      ctx->enc, ctx->key.rounds, blocks,
0172                      walk.iv);
0173         kernel_neon_end();
0174         err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
0175     }
0176     return err;
0177 }
0178 
0179 static int cbc_decrypt(struct skcipher_request *req)
0180 {
0181     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0182     struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0183     struct skcipher_walk walk;
0184     int err;
0185 
0186     err = skcipher_walk_virt(&walk, req, false);
0187 
0188     while (walk.nbytes >= AES_BLOCK_SIZE) {
0189         unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
0190 
0191         if (walk.nbytes < walk.total)
0192             blocks = round_down(blocks,
0193                         walk.stride / AES_BLOCK_SIZE);
0194 
0195         kernel_neon_begin();
0196         aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
0197                   ctx->key.rk, ctx->key.rounds, blocks,
0198                   walk.iv);
0199         kernel_neon_end();
0200         err = skcipher_walk_done(&walk,
0201                      walk.nbytes - blocks * AES_BLOCK_SIZE);
0202     }
0203 
0204     return err;
0205 }
0206 
0207 static int ctr_encrypt(struct skcipher_request *req)
0208 {
0209     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0210     struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
0211     struct skcipher_walk walk;
0212     int err;
0213 
0214     err = skcipher_walk_virt(&walk, req, false);
0215 
0216     while (walk.nbytes > 0) {
0217         int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
0218         int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
0219         const u8 *src = walk.src.virt.addr;
0220         u8 *dst = walk.dst.virt.addr;
0221 
0222         kernel_neon_begin();
0223         if (blocks >= 8) {
0224             aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
0225                       blocks, walk.iv);
0226             dst += blocks * AES_BLOCK_SIZE;
0227             src += blocks * AES_BLOCK_SIZE;
0228         }
0229         if (nbytes && walk.nbytes == walk.total) {
0230             neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
0231                          nbytes, walk.iv);
0232             nbytes = 0;
0233         }
0234         kernel_neon_end();
0235         err = skcipher_walk_done(&walk, nbytes);
0236     }
0237     return err;
0238 }
0239 
0240 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
0241                 unsigned int key_len)
0242 {
0243     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0244     struct crypto_aes_ctx rk;
0245     int err;
0246 
0247     err = xts_verify_key(tfm, in_key, key_len);
0248     if (err)
0249         return err;
0250 
0251     key_len /= 2;
0252     err = aes_expandkey(&ctx->cts, in_key, key_len);
0253     if (err)
0254         return err;
0255 
0256     err = aes_expandkey(&rk, in_key + key_len, key_len);
0257     if (err)
0258         return err;
0259 
0260     memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
0261 
0262     return aesbs_setkey(tfm, in_key, key_len);
0263 }
0264 
0265 static int __xts_crypt(struct skcipher_request *req, bool encrypt,
0266                void (*fn)(u8 out[], u8 const in[], u8 const rk[],
0267                   int rounds, int blocks, u8 iv[]))
0268 {
0269     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0270     struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
0271     int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
0272     struct scatterlist sg_src[2], sg_dst[2];
0273     struct skcipher_request subreq;
0274     struct scatterlist *src, *dst;
0275     struct skcipher_walk walk;
0276     int nbytes, err;
0277     int first = 1;
0278     u8 *out, *in;
0279 
0280     if (req->cryptlen < AES_BLOCK_SIZE)
0281         return -EINVAL;
0282 
0283     /* ensure that the cts tail is covered by a single step */
0284     if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
0285         int xts_blocks = DIV_ROUND_UP(req->cryptlen,
0286                           AES_BLOCK_SIZE) - 2;
0287 
0288         skcipher_request_set_tfm(&subreq, tfm);
0289         skcipher_request_set_callback(&subreq,
0290                           skcipher_request_flags(req),
0291                           NULL, NULL);
0292         skcipher_request_set_crypt(&subreq, req->src, req->dst,
0293                        xts_blocks * AES_BLOCK_SIZE,
0294                        req->iv);
0295         req = &subreq;
0296     } else {
0297         tail = 0;
0298     }
0299 
0300     err = skcipher_walk_virt(&walk, req, false);
0301     if (err)
0302         return err;
0303 
0304     while (walk.nbytes >= AES_BLOCK_SIZE) {
0305         int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
0306         out = walk.dst.virt.addr;
0307         in = walk.src.virt.addr;
0308         nbytes = walk.nbytes;
0309 
0310         kernel_neon_begin();
0311         if (blocks >= 8) {
0312             if (first == 1)
0313                 neon_aes_ecb_encrypt(walk.iv, walk.iv,
0314                              ctx->twkey,
0315                              ctx->key.rounds, 1);
0316             first = 2;
0317 
0318             fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
0319                walk.iv);
0320 
0321             out += blocks * AES_BLOCK_SIZE;
0322             in += blocks * AES_BLOCK_SIZE;
0323             nbytes -= blocks * AES_BLOCK_SIZE;
0324         }
0325         if (walk.nbytes == walk.total && nbytes > 0) {
0326             if (encrypt)
0327                 neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
0328                              ctx->key.rounds, nbytes,
0329                              ctx->twkey, walk.iv, first);
0330             else
0331                 neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
0332                              ctx->key.rounds, nbytes,
0333                              ctx->twkey, walk.iv, first);
0334             nbytes = first = 0;
0335         }
0336         kernel_neon_end();
0337         err = skcipher_walk_done(&walk, nbytes);
0338     }
0339 
0340     if (err || likely(!tail))
0341         return err;
0342 
0343     /* handle ciphertext stealing */
0344     dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
0345     if (req->dst != req->src)
0346         dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
0347 
0348     skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
0349                    req->iv);
0350 
0351     err = skcipher_walk_virt(&walk, req, false);
0352     if (err)
0353         return err;
0354 
0355     out = walk.dst.virt.addr;
0356     in = walk.src.virt.addr;
0357     nbytes = walk.nbytes;
0358 
0359     kernel_neon_begin();
0360     if (encrypt)
0361         neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
0362                      nbytes, ctx->twkey, walk.iv, first);
0363     else
0364         neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
0365                      nbytes, ctx->twkey, walk.iv, first);
0366     kernel_neon_end();
0367 
0368     return skcipher_walk_done(&walk, 0);
0369 }
0370 
0371 static int xts_encrypt(struct skcipher_request *req)
0372 {
0373     return __xts_crypt(req, true, aesbs_xts_encrypt);
0374 }
0375 
0376 static int xts_decrypt(struct skcipher_request *req)
0377 {
0378     return __xts_crypt(req, false, aesbs_xts_decrypt);
0379 }
0380 
0381 static struct skcipher_alg aes_algs[] = { {
0382     .base.cra_name      = "ecb(aes)",
0383     .base.cra_driver_name   = "ecb-aes-neonbs",
0384     .base.cra_priority  = 250,
0385     .base.cra_blocksize = AES_BLOCK_SIZE,
0386     .base.cra_ctxsize   = sizeof(struct aesbs_ctx),
0387     .base.cra_module    = THIS_MODULE,
0388 
0389     .min_keysize        = AES_MIN_KEY_SIZE,
0390     .max_keysize        = AES_MAX_KEY_SIZE,
0391     .walksize       = 8 * AES_BLOCK_SIZE,
0392     .setkey         = aesbs_setkey,
0393     .encrypt        = ecb_encrypt,
0394     .decrypt        = ecb_decrypt,
0395 }, {
0396     .base.cra_name      = "cbc(aes)",
0397     .base.cra_driver_name   = "cbc-aes-neonbs",
0398     .base.cra_priority  = 250,
0399     .base.cra_blocksize = AES_BLOCK_SIZE,
0400     .base.cra_ctxsize   = sizeof(struct aesbs_cbc_ctr_ctx),
0401     .base.cra_module    = THIS_MODULE,
0402 
0403     .min_keysize        = AES_MIN_KEY_SIZE,
0404     .max_keysize        = AES_MAX_KEY_SIZE,
0405     .walksize       = 8 * AES_BLOCK_SIZE,
0406     .ivsize         = AES_BLOCK_SIZE,
0407     .setkey         = aesbs_cbc_ctr_setkey,
0408     .encrypt        = cbc_encrypt,
0409     .decrypt        = cbc_decrypt,
0410 }, {
0411     .base.cra_name      = "ctr(aes)",
0412     .base.cra_driver_name   = "ctr-aes-neonbs",
0413     .base.cra_priority  = 250,
0414     .base.cra_blocksize = 1,
0415     .base.cra_ctxsize   = sizeof(struct aesbs_cbc_ctr_ctx),
0416     .base.cra_module    = THIS_MODULE,
0417 
0418     .min_keysize        = AES_MIN_KEY_SIZE,
0419     .max_keysize        = AES_MAX_KEY_SIZE,
0420     .chunksize      = AES_BLOCK_SIZE,
0421     .walksize       = 8 * AES_BLOCK_SIZE,
0422     .ivsize         = AES_BLOCK_SIZE,
0423     .setkey         = aesbs_cbc_ctr_setkey,
0424     .encrypt        = ctr_encrypt,
0425     .decrypt        = ctr_encrypt,
0426 }, {
0427     .base.cra_name      = "xts(aes)",
0428     .base.cra_driver_name   = "xts-aes-neonbs",
0429     .base.cra_priority  = 250,
0430     .base.cra_blocksize = AES_BLOCK_SIZE,
0431     .base.cra_ctxsize   = sizeof(struct aesbs_xts_ctx),
0432     .base.cra_module    = THIS_MODULE,
0433 
0434     .min_keysize        = 2 * AES_MIN_KEY_SIZE,
0435     .max_keysize        = 2 * AES_MAX_KEY_SIZE,
0436     .walksize       = 8 * AES_BLOCK_SIZE,
0437     .ivsize         = AES_BLOCK_SIZE,
0438     .setkey         = aesbs_xts_setkey,
0439     .encrypt        = xts_encrypt,
0440     .decrypt        = xts_decrypt,
0441 } };
0442 
0443 static void aes_exit(void)
0444 {
0445     crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0446 }
0447 
0448 static int __init aes_init(void)
0449 {
0450     if (!cpu_have_named_feature(ASIMD))
0451         return -ENODEV;
0452 
0453     return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
0454 }
0455 
/* Hook the init and exit handlers into the module loader. */
module_init(aes_init);
module_exit(aes_exit);