// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

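/*
 * The routines below are implemented in NEON assembly, using a bit-sliced
 * representation of the AES state that operates on eight blocks in parallel.
 */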
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

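/*
 * rk[] holds the round keys in the bit-sliced format produced by
 * aesbs_convert_key().
 */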
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};

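/*
 * Expand the key with the generic AES library, then convert the encryption
 * round keys into the bit-sliced layout expected by the NEON code.
 */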
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

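/*
 * ECB helper: hand multiples of the walk stride (eight blocks) to the NEON
 * code, processing any smaller remainder in the final iteration.
 */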
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

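/*
 * CBC encryption is inherently sequential and cannot use the eight-way
 * bit-sliced code, so delegate it to the fallback cbc(aes) transform
 * allocated in cbc_init().
 */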
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

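/*
 * Allocate the fallback cbc(aes) transform used for encryption, and size
 * the request context so it can hold a subrequest for it.
 */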
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

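/*
 * For the sync CTR algorithm, keep the expanded key around as well so the
 * scalar AES library can be used when the NEON unit is not available.
 */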
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

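/*
 * CTR mode: intermediate chunks are rounded down to a multiple of eight
 * blocks; a tail shorter than one block is bounced through a stack buffer.
 */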
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

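/*
 * If the NEON unit cannot be used (e.g. the call is made from a context
 * where SIMD is not allowed), fall back to the scalar AES library via
 * crypto_ctr_encrypt_walk().
 */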
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

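/*
 * XTS: the first half of the key drives the bit-sliced data cipher (and the
 * scalar cipher used for ciphertext stealing), the second half drives the
 * tweak cipher.
 */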
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

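/*
 * XTS helper: full blocks go through the NEON code; if the request length
 * is not a multiple of the block size, the trailing partial block is
 * handled with ciphertext stealing, using the scalar cts_tfm cipher for
 * the extra block operation.
 */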
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

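	/* encrypt the IV with the tweak key to generate the initial tweak */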
	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

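	/* handle ciphertext stealing for the final partial block */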
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

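/*
 * The __-prefixed entries are marked CRYPTO_ALG_INTERNAL: they assume the
 * NEON unit is usable and are only reachable through the SIMD wrappers
 * created in aes_init(). The ctr-aes-neonbs-sync entry checks for NEON
 * availability itself and can therefore be exposed directly.
 */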
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

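/*
 * Register the skciphers, then create a SIMD wrapper (dropping the "__"
 * prefix from the name and driver name) for each internal algorithm.
 */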
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);