// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Ard Biesheuvel <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

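/*
 * Low-level bit sliced transforms, implemented in assembler. They consume the
 * bit sliced key layout produced by aesbs_convert_key() and are most
 * efficient when fed multiples of eight AES blocks.
 */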
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

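/*
 * The bit sliced form of the expanded key, as produced by
 * aesbs_convert_key(). The buffer is sized for the largest (AES-256) key
 * schedule.
 */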
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

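/*
 * Expand the user supplied key and convert it into the bit sliced layout the
 * assembler code expects.
 */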
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

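/*
 * Shared ECB helper: hand as many full blocks as possible to the bit sliced
 * transform, rounding intermediate steps down to a multiple of the walk
 * stride (eight blocks) so that only the final step sees a ragged chunk.
 */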
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

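/*
 * CBC and CTR share a context: the bit sliced key for the parallelisable
 * paths, plus a copy of the regular encryption round keys (ctx->enc) for the
 * plain NEON helpers (CBC encryption and the CTR tail).
 */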
static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

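/*
 * CBC encryption is inherently sequential (each block depends on the previous
 * ciphertext block), so there is nothing to gain from bit slicing here.
 */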
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

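/*
 * CBC decryption has no such dependency chain, so full chunks can go through
 * the bit sliced transform.
 */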
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

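/*
 * CTR mode: encrypt the counter stream for full eight block chunks with the
 * bit sliced code, and let the plain NEON helper deal with whatever is left
 * at the end of the request.
 */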
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_neon_begin();
		if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
			dst += blocks * AES_BLOCK_SIZE;
			src += blocks * AES_BLOCK_SIZE;
		}
		if (nbytes && walk.nbytes == walk.total) {
			u8 buf[AES_BLOCK_SIZE];
			u8 *d = dst;

			/*
			 * Bounce a final partial block via a stack buffer so
			 * the NEON helper can load and store a whole block
			 * without stepping past the end of the walk buffers.
			 */
			if (unlikely(nbytes < AES_BLOCK_SIZE))
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);

			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);

			if (unlikely(nbytes < AES_BLOCK_SIZE))
				memcpy(d, dst, nbytes);

			nbytes = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

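/*
 * XTS uses two AES keys: the first half of the input key drives the data
 * transform (kept both as a regular key schedule for the ciphertext stealing
 * tail and in bit sliced form), the second half is the tweak key used to
 * encrypt the IV.
 */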
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

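/*
 * Shared XTS helper. Requests that are not a multiple of the block size are
 * finished using ciphertext stealing in the plain NEON routines: if the
 * partial tail would otherwise end up in a walk step of its own, the main
 * walk is shrunk so that one full block is kept back and processed together
 * with the tail at the end.
 */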
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (blocks >= 8) {
			if (first == 1)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 2;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0) {
			if (encrypt)
				neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			else
				neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
						     ctx->key.rounds, nbytes,
						     ctx->twkey, walk.iv, first);
			nbytes = first = 0;
		}
		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

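/*
 * All modes advertise a walksize of eight AES blocks so that the skcipher
 * walk hands the loops above chunks the eight-way bit sliced code can
 * consume directly.
 */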
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

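/*
 * Advanced SIMD (NEON) is required by every code path in this driver, so
 * refuse to load if it is not available.
 */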
static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);