/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized,
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE   (SM4_BLOCK_SIZE * 8)

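/*
 * Bulk helpers implemented in assembly: crypt4/crypt8 process up to four
 * or eight independent blocks per call; the ctr/cbc/cfb variants each
 * consume one eight-block chunk and keep the IV/counter passed in 'iv'
 * up to date for the next call.
 */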
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
                const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
                const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
                const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
                const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
                const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
            unsigned int key_len)
{
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

    return sm4_expandkey(ctx, key, key_len);
}

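/*
 * ECB: inside one kernel_fpu_begin()/end() section per walk step, consume
 * full eight-block chunks with the crypt8 helper, then finish the remaining
 * whole blocks (up to four at a time) with crypt4.  Any leftover bytes are
 * handed back to skcipher_walk_done().
 */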
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        kernel_fpu_begin();
        while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
            sm4_aesni_avx_crypt8(rkey, dst, src, 8);
            dst += SM4_CRYPT8_BLOCK_SIZE;
            src += SM4_CRYPT8_BLOCK_SIZE;
            nbytes -= SM4_CRYPT8_BLOCK_SIZE;
        }
        while (nbytes >= SM4_BLOCK_SIZE) {
            unsigned int nblocks = min(nbytes >> 4, 4u);
            sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
            dst += nblocks * SM4_BLOCK_SIZE;
            src += nblocks * SM4_BLOCK_SIZE;
            nbytes -= nblocks * SM4_BLOCK_SIZE;
        }
        kernel_fpu_end();

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

    return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

    return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

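/*
 * CBC encryption is inherently sequential (each block's input depends on
 * the previous ciphertext block), so it uses the generic sm4_crypt_block()
 * one block at a time and does not need the FPU/AVX path.
 */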
int sm4_cbc_encrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *iv = walk.iv;
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        while (nbytes >= SM4_BLOCK_SIZE) {
            crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
            sm4_crypt_block(ctx->rkey_enc, dst, dst);
            iv = dst;
            src += SM4_BLOCK_SIZE;
            dst += SM4_BLOCK_SIZE;
            nbytes -= SM4_BLOCK_SIZE;
        }
        if (iv != walk.iv)
            memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

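/*
 * CBC decryption parallelizes well: full 'bsize' chunks go straight to the
 * assembly helper, which also carries the IV forward.  A remaining run of
 * up to eight whole blocks is block-decrypted into a local buffer with
 * crypt8 and then XORed with the preceding ciphertext blocks walking
 * backwards, so that in-place requests are handled before the ciphertext
 * needed as chaining input is overwritten.
 */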
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
            unsigned int bsize, sm4_crypt_func func)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        kernel_fpu_begin();

        while (nbytes >= bsize) {
            func(ctx->rkey_dec, dst, src, walk.iv);
            dst += bsize;
            src += bsize;
            nbytes -= bsize;
        }

        while (nbytes >= SM4_BLOCK_SIZE) {
            u8 keystream[SM4_BLOCK_SIZE * 8];
            u8 iv[SM4_BLOCK_SIZE];
            unsigned int nblocks = min(nbytes >> 4, 8u);
            int i;

            sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
                        src, nblocks);

            src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
            dst += (nblocks - 1) * SM4_BLOCK_SIZE;
            memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

            for (i = nblocks - 1; i > 0; i--) {
                crypto_xor_cpy(dst, src,
                    &keystream[i * SM4_BLOCK_SIZE],
                    SM4_BLOCK_SIZE);
                src -= SM4_BLOCK_SIZE;
                dst -= SM4_BLOCK_SIZE;
            }
            crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
            memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
            dst += nblocks * SM4_BLOCK_SIZE;
            src += (nblocks + 1) * SM4_BLOCK_SIZE;
            nbytes -= nblocks * SM4_BLOCK_SIZE;
        }

        kernel_fpu_end();
        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
    return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
                sm4_aesni_avx_cbc_dec_blk8);
}

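/*
 * CFB encryption is sequential: each keystream block is the encryption of
 * the previous ciphertext block (the IV for the first one).  The tail
 * branch handles a final partial block, which is valid because CFB is
 * registered as a stream mode (cra_blocksize is 1).
 */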
int sm4_cfb_encrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        u8 keystream[SM4_BLOCK_SIZE];
        const u8 *iv = walk.iv;
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        while (nbytes >= SM4_BLOCK_SIZE) {
            sm4_crypt_block(ctx->rkey_enc, keystream, iv);
            crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
            iv = dst;
            src += SM4_BLOCK_SIZE;
            dst += SM4_BLOCK_SIZE;
            nbytes -= SM4_BLOCK_SIZE;
        }
        if (iv != walk.iv)
            memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

        /* tail */
        if (walk.nbytes == walk.total && nbytes > 0) {
            sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
            crypto_xor_cpy(dst, src, keystream, nbytes);
            nbytes = 0;
        }

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);

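/*
 * CFB decryption parallelizes: the keystream for blocks 1..n is the
 * encryption of (IV, C1, ..., Cn-1).  For a chunk smaller than 'bsize'
 * those inputs are gathered into a local buffer so crypt8 can encrypt
 * them all at once, and the result is then XORed with the ciphertext.
 */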
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
            unsigned int bsize, sm4_crypt_func func)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        kernel_fpu_begin();

        while (nbytes >= bsize) {
            func(ctx->rkey_enc, dst, src, walk.iv);
            dst += bsize;
            src += bsize;
            nbytes -= bsize;
        }

        while (nbytes >= SM4_BLOCK_SIZE) {
            u8 keystream[SM4_BLOCK_SIZE * 8];
            unsigned int nblocks = min(nbytes >> 4, 8u);

            memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
            if (nblocks > 1)
                memcpy(&keystream[SM4_BLOCK_SIZE], src,
                    (nblocks - 1) * SM4_BLOCK_SIZE);
            memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
                SM4_BLOCK_SIZE);

            sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
                        keystream, nblocks);

            crypto_xor_cpy(dst, src, keystream,
                    nblocks * SM4_BLOCK_SIZE);
            dst += nblocks * SM4_BLOCK_SIZE;
            src += nblocks * SM4_BLOCK_SIZE;
            nbytes -= nblocks * SM4_BLOCK_SIZE;
        }

        kernel_fpu_end();

        /* tail */
        if (walk.nbytes == walk.total && nbytes > 0) {
            u8 keystream[SM4_BLOCK_SIZE];

            sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
            crypto_xor_cpy(dst, src, keystream, nbytes);
            nbytes = 0;
        }

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);

static int cfb_decrypt(struct skcipher_request *req)
{
    return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
                sm4_aesni_avx_cfb_dec_blk8);
}

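/*
 * CTR mode: full 'bsize' chunks use the assembly helper, which advances the
 * counter itself.  Smaller runs expand the counter into a local buffer with
 * crypto_inc(), encrypt all the counter blocks at once with crypt8 and XOR
 * with the input; a final partial block falls back to the generic
 * single-block routine.
 */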
int sm4_avx_ctr_crypt(struct skcipher_request *req,
            unsigned int bsize, sm4_crypt_func func)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;

        kernel_fpu_begin();

        while (nbytes >= bsize) {
            func(ctx->rkey_enc, dst, src, walk.iv);
            dst += bsize;
            src += bsize;
            nbytes -= bsize;
        }

        while (nbytes >= SM4_BLOCK_SIZE) {
            u8 keystream[SM4_BLOCK_SIZE * 8];
            unsigned int nblocks = min(nbytes >> 4, 8u);
            int i;

            for (i = 0; i < nblocks; i++) {
                memcpy(&keystream[i * SM4_BLOCK_SIZE],
                    walk.iv, SM4_BLOCK_SIZE);
                crypto_inc(walk.iv, SM4_BLOCK_SIZE);
            }
            sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
                    keystream, nblocks);

            crypto_xor_cpy(dst, src, keystream,
                    nblocks * SM4_BLOCK_SIZE);
            dst += nblocks * SM4_BLOCK_SIZE;
            src += nblocks * SM4_BLOCK_SIZE;
            nbytes -= nblocks * SM4_BLOCK_SIZE;
        }

        kernel_fpu_end();

        /* tail */
        if (walk.nbytes == walk.total && nbytes > 0) {
            u8 keystream[SM4_BLOCK_SIZE];

            memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
            crypto_inc(walk.iv, SM4_BLOCK_SIZE);

            sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

            crypto_xor_cpy(dst, src, keystream, nbytes);
            dst += nbytes;
            src += nbytes;
            nbytes = 0;
        }

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
    return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
                sm4_aesni_avx_ctr_enc_blk8);
}

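/*
 * The algorithms below carry a "__" prefix and CRYPTO_ALG_INTERNAL, so they
 * are never selected directly; simd_register_skciphers_compat() wraps them
 * in SIMD helper algorithms that defer work to an asynchronous context
 * whenever the FPU cannot be used.
 */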
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
    {
        .base = {
            .cra_name       = "__ecb(sm4)",
            .cra_driver_name    = "__ecb-sm4-aesni-avx",
            .cra_priority       = 400,
            .cra_flags      = CRYPTO_ALG_INTERNAL,
            .cra_blocksize      = SM4_BLOCK_SIZE,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .walksize   = 8 * SM4_BLOCK_SIZE,
        .setkey     = sm4_skcipher_setkey,
        .encrypt    = sm4_avx_ecb_encrypt,
        .decrypt    = sm4_avx_ecb_decrypt,
    }, {
        .base = {
            .cra_name       = "__cbc(sm4)",
            .cra_driver_name    = "__cbc-sm4-aesni-avx",
            .cra_priority       = 400,
            .cra_flags      = CRYPTO_ALG_INTERNAL,
            .cra_blocksize      = SM4_BLOCK_SIZE,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .walksize   = 8 * SM4_BLOCK_SIZE,
        .setkey     = sm4_skcipher_setkey,
        .encrypt    = sm4_cbc_encrypt,
        .decrypt    = cbc_decrypt,
    }, {
        .base = {
            .cra_name       = "__cfb(sm4)",
            .cra_driver_name    = "__cfb-sm4-aesni-avx",
            .cra_priority       = 400,
            .cra_flags      = CRYPTO_ALG_INTERNAL,
            .cra_blocksize      = 1,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .chunksize  = SM4_BLOCK_SIZE,
        .walksize   = 8 * SM4_BLOCK_SIZE,
        .setkey     = sm4_skcipher_setkey,
        .encrypt    = sm4_cfb_encrypt,
        .decrypt    = cfb_decrypt,
    }, {
        .base = {
            .cra_name       = "__ctr(sm4)",
            .cra_driver_name    = "__ctr-sm4-aesni-avx",
            .cra_priority       = 400,
            .cra_flags      = CRYPTO_ALG_INTERNAL,
            .cra_blocksize      = 1,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .chunksize  = SM4_BLOCK_SIZE,
        .walksize   = 8 * SM4_BLOCK_SIZE,
        .setkey     = sm4_skcipher_setkey,
        .encrypt    = ctr_crypt,
        .decrypt    = ctr_crypt,
    }
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

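/*
 * Refuse to load unless the CPU advertises AVX, AES-NI and OSXSAVE, and the
 * OS has enabled saving of the SSE and YMM register state, since the
 * assembly helpers depend on all of these.
 */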
static int __init sm4_init(void)
{
    const char *feature_name;

    if (!boot_cpu_has(X86_FEATURE_AVX) ||
        !boot_cpu_has(X86_FEATURE_AES) ||
        !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
        pr_info("AVX or AES-NI instructions are not detected.\n");
        return -ENODEV;
    }

    if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
                &feature_name)) {
        pr_info("CPU feature '%s' is not supported.\n", feature_name);
        return -ENODEV;
    }

    return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
                    ARRAY_SIZE(sm4_aesni_avx_skciphers),
                    simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
    simd_unregister_skciphers(sm4_aesni_avx_skciphers,
                    ARRAY_SIZE(sm4_aesni_avx_skciphers),
                    simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");