/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

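/*
 * Bulk primitives from the AES-NI/AVX assembly implementation: crypt4 and
 * crypt8 encrypt or decrypt up to four or eight independent blocks (ECB
 * fashion) with the given round keys, while the *_blk8 helpers process
 * eight blocks of CTR, CBC decryption or CFB decryption at a time and
 * advance the IV/counter they are given.
 */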
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

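/*
 * ECB: walk the request and process eight blocks per call to the AVX code
 * while enough data remains, then up to four blocks at a time, all inside a
 * single kernel_fpu_begin()/kernel_fpu_end() section per walk step.
 */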
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);

			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

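/*
 * CBC encryption is inherently serial (each block chains the previous
 * ciphertext block), so it is done one block at a time with the generic
 * sm4_crypt_block() and never touches the FPU/SIMD state.
 */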
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

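/*
 * CBC decryption: full groups of @bsize bytes go straight to the assembly
 * helper @func. The tail decrypts the remaining (up to eight) blocks in ECB
 * fashion into a keystream buffer and then XORs in the preceding ciphertext
 * blocks walking backwards, so each ciphertext block is consumed before it
 * can be overwritten and in-place (src == dst) requests still work.
 */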
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}

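/*
 * CFB encryption is serial: each keystream block is the encryption of the
 * previous ciphertext block (the IV for the first one). A trailing partial
 * block only XORs the remaining nbytes of keystream, which is why the mode
 * is registered with a block size of 1.
 */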
int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);

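/*
 * CFB decryption parallelizes well: the keystream is the encryption of the
 * IV followed by the first nblocks - 1 ciphertext blocks. Full 8-block
 * groups go straight to the assembly helper @func; the tail builds the
 * keystream buffer, encrypts it ECB-style and XORs it with the ciphertext,
 * and a final partial block falls back to a single sm4_crypt_block().
 */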
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblocks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblocks - 1) * SM4_BLOCK_SIZE);
			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
						keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);

static int cfb_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cfb_dec_blk8);
}

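/*
 * CTR mode: full 8-block groups are handled by the assembly helper, which
 * also advances the counter in walk.iv. The tail expands the counter into a
 * keystream buffer by hand (crypto_inc() per block), encrypts it ECB-style
 * and XORs it with the data; a final partial block uses a single
 * sm4_crypt_block() so requests of any length are supported.
 */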
int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}

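/*
 * The "__" prefixed, CRYPTO_ALG_INTERNAL algorithms may only run when the
 * FPU is usable; simd_register_skciphers_compat() wraps them in simd
 * helpers so callers that cannot use SIMD fall back to an async path.
 */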
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cfb(sm4)",
			.cra_driver_name	= "__cfb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

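/*
 * Require AVX, AES-NI and OSXSAVE, plus kernel support for saving and
 * restoring the SSE and YMM register state, before registering the
 * simd-wrapped algorithms.
 */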
static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
				ARRAY_SIZE(sm4_aesni_avx_skciphers),
				simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");