Back to home page

OSCL-LXR

 
 

    


0001 /* SPDX-License-Identifier: GPL-2.0-or-later */
0002 /*
0003  * SM4 Cipher Algorithm, using ARMv8 NEON
0004  * as specified in
0005  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
0006  *
0007  * Copyright (C) 2022, Alibaba Group.
0008  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
0009  */
0010 
0011 #include <linux/module.h>
0012 #include <linux/crypto.h>
0013 #include <linux/kernel.h>
0014 #include <linux/cpufeature.h>
0015 #include <asm/neon.h>
0016 #include <asm/simd.h>
0017 #include <crypto/internal/simd.h>
0018 #include <crypto/internal/skcipher.h>
0019 #include <crypto/sm4.h>
0020 
/* Number of whole 16-byte SM4 blocks contained in nbytes */
#define BYTES2BLKS(nbytes)  ((nbytes) >> 4)
/* Number of whole blocks in nbytes, rounded down to a multiple of 8 */
#define BYTES2BLK8(nbytes)  (((nbytes) >> 4) & ~(8 - 1))

/*
 * NEON assembly routines (implemented in the accompanying .S file --
 * presumably sm4-neon-core.S; confirm against the build).  nblks is a
 * count of 16-byte blocks.  The blk1_8 variant handles 1..8 blocks,
 * the blk8 variants require a multiple of 8.  The cbc/cfb/ctr helpers
 * also update the chaining value / counter in iv.
 */
asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
                      unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                    unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                      u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
                      u8 *iv, unsigned int nblks);
0034 
0035 static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
0036               unsigned int key_len)
0037 {
0038     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0039 
0040     return sm4_expandkey(ctx, key, key_len);
0041 }
0042 
0043 static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
0044 {
0045     struct skcipher_walk walk;
0046     unsigned int nbytes;
0047     int err;
0048 
0049     err = skcipher_walk_virt(&walk, req, false);
0050 
0051     while ((nbytes = walk.nbytes) > 0) {
0052         const u8 *src = walk.src.virt.addr;
0053         u8 *dst = walk.dst.virt.addr;
0054         unsigned int nblks;
0055 
0056         kernel_neon_begin();
0057 
0058         nblks = BYTES2BLK8(nbytes);
0059         if (nblks) {
0060             sm4_neon_crypt_blk8(rkey, dst, src, nblks);
0061             dst += nblks * SM4_BLOCK_SIZE;
0062             src += nblks * SM4_BLOCK_SIZE;
0063             nbytes -= nblks * SM4_BLOCK_SIZE;
0064         }
0065 
0066         nblks = BYTES2BLKS(nbytes);
0067         if (nblks) {
0068             sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
0069             nbytes -= nblks * SM4_BLOCK_SIZE;
0070         }
0071 
0072         kernel_neon_end();
0073 
0074         err = skcipher_walk_done(&walk, nbytes);
0075     }
0076 
0077     return err;
0078 }
0079 
0080 static int sm4_ecb_encrypt(struct skcipher_request *req)
0081 {
0082     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0083     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0084 
0085     return sm4_ecb_do_crypt(req, ctx->rkey_enc);
0086 }
0087 
0088 static int sm4_ecb_decrypt(struct skcipher_request *req)
0089 {
0090     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0091     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0092 
0093     return sm4_ecb_do_crypt(req, ctx->rkey_dec);
0094 }
0095 
0096 static int sm4_cbc_encrypt(struct skcipher_request *req)
0097 {
0098     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0099     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0100     struct skcipher_walk walk;
0101     unsigned int nbytes;
0102     int err;
0103 
0104     err = skcipher_walk_virt(&walk, req, false);
0105 
0106     while ((nbytes = walk.nbytes) > 0) {
0107         const u8 *iv = walk.iv;
0108         const u8 *src = walk.src.virt.addr;
0109         u8 *dst = walk.dst.virt.addr;
0110 
0111         while (nbytes >= SM4_BLOCK_SIZE) {
0112             crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
0113             sm4_crypt_block(ctx->rkey_enc, dst, dst);
0114             iv = dst;
0115             src += SM4_BLOCK_SIZE;
0116             dst += SM4_BLOCK_SIZE;
0117             nbytes -= SM4_BLOCK_SIZE;
0118         }
0119         if (iv != walk.iv)
0120             memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
0121 
0122         err = skcipher_walk_done(&walk, nbytes);
0123     }
0124 
0125     return err;
0126 }
0127 
/*
 * CBC decryption.  Bulk 8-block groups are handled by the NEON routine
 * (which also updates walk.iv).  A 1..7 block tail is ECB-decrypted
 * into a scratch buffer, then the CBC XOR chain is applied walking
 * backwards (last block first) so that src and dst may overlap
 * in-place without clobbering ciphertext that is still needed.
 */
static int sm4_cbc_decrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;
        unsigned int nblks;

        kernel_neon_begin();

        /* bulk path: multiples of 8 blocks, IV updated by the asm */
        nblks = BYTES2BLK8(nbytes);
        if (nblks) {
            sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
                    walk.iv, nblks);
            dst += nblks * SM4_BLOCK_SIZE;
            src += nblks * SM4_BLOCK_SIZE;
            nbytes -= nblks * SM4_BLOCK_SIZE;
        }

        /* tail path: 1..7 remaining full blocks */
        nblks = BYTES2BLKS(nbytes);
        if (nblks) {
            u8 keystream[SM4_BLOCK_SIZE * 8];
            u8 iv[SM4_BLOCK_SIZE];
            int i;

            /* keystream[i] = D(C_i) for all tail blocks */
            sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
                    src, nblks);

            /*
             * Point src at C_{nblks-2} (previous ciphertext of the
             * last block) and dst at the last output block.  The
             * (int) cast keeps the offset signed: for nblks == 1 it
             * is -1 block, and src is then only ever read via
             * src + SM4_BLOCK_SIZE below.
             */
            src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
            dst += (nblks - 1) * SM4_BLOCK_SIZE;
            /* save C_{nblks-1} as the next IV before it can be overwritten */
            memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

            /* P_i = D(C_i) ^ C_{i-1}, walking backwards to allow in-place */
            for (i = nblks - 1; i > 0; i--) {
                crypto_xor_cpy(dst, src,
                    &keystream[i * SM4_BLOCK_SIZE],
                    SM4_BLOCK_SIZE);
                src -= SM4_BLOCK_SIZE;
                dst -= SM4_BLOCK_SIZE;
            }
            /* first tail block chains off the incoming IV */
            crypto_xor_cpy(dst, walk.iv,
                    keystream, SM4_BLOCK_SIZE);
            memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
            nbytes -= nblks * SM4_BLOCK_SIZE;
        }

        kernel_neon_end();

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
0187 
0188 static int sm4_cfb_encrypt(struct skcipher_request *req)
0189 {
0190     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0191     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0192     struct skcipher_walk walk;
0193     unsigned int nbytes;
0194     int err;
0195 
0196     err = skcipher_walk_virt(&walk, req, false);
0197 
0198     while ((nbytes = walk.nbytes) > 0) {
0199         u8 keystream[SM4_BLOCK_SIZE];
0200         const u8 *iv = walk.iv;
0201         const u8 *src = walk.src.virt.addr;
0202         u8 *dst = walk.dst.virt.addr;
0203 
0204         while (nbytes >= SM4_BLOCK_SIZE) {
0205             sm4_crypt_block(ctx->rkey_enc, keystream, iv);
0206             crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
0207             iv = dst;
0208             src += SM4_BLOCK_SIZE;
0209             dst += SM4_BLOCK_SIZE;
0210             nbytes -= SM4_BLOCK_SIZE;
0211         }
0212         if (iv != walk.iv)
0213             memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
0214 
0215         /* tail */
0216         if (walk.nbytes == walk.total && nbytes > 0) {
0217             sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
0218             crypto_xor_cpy(dst, src, keystream, nbytes);
0219             nbytes = 0;
0220         }
0221 
0222         err = skcipher_walk_done(&walk, nbytes);
0223     }
0224 
0225     return err;
0226 }
0227 
/*
 * CFB decryption.  Unlike encryption this parallelizes (all cipher
 * inputs are ciphertext blocks already in hand), so bulk 8-block
 * groups use the NEON routine.  A 1..7 block tail assembles the
 * cipher inputs [IV, C_0 .. C_{nblks-2}] in a scratch buffer,
 * encrypts them in place, and XORs with the ciphertext:
 * P_i = C_i ^ E(C_{i-1}).
 */
static int sm4_cfb_decrypt(struct skcipher_request *req)
{
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
    struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
    struct skcipher_walk walk;
    unsigned int nbytes;
    int err;

    err = skcipher_walk_virt(&walk, req, false);

    while ((nbytes = walk.nbytes) > 0) {
        const u8 *src = walk.src.virt.addr;
        u8 *dst = walk.dst.virt.addr;
        unsigned int nblks;

        kernel_neon_begin();

        /* bulk path: multiples of 8 blocks, IV updated by the asm */
        nblks = BYTES2BLK8(nbytes);
        if (nblks) {
            sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
                    walk.iv, nblks);
            dst += nblks * SM4_BLOCK_SIZE;
            src += nblks * SM4_BLOCK_SIZE;
            nbytes -= nblks * SM4_BLOCK_SIZE;
        }

        /* tail path: 1..7 remaining full blocks */
        nblks = BYTES2BLKS(nbytes);
        if (nblks) {
            u8 keystream[SM4_BLOCK_SIZE * 8];

            /* cipher inputs: IV followed by C_0 .. C_{nblks-2} */
            memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
            if (nblks > 1)
                memcpy(&keystream[SM4_BLOCK_SIZE], src,
                    (nblks - 1) * SM4_BLOCK_SIZE);
            /*
             * Save C_{nblks-1} as the next IV now, before the XOR
             * below can overwrite src when decrypting in place.
             */
            memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
                SM4_BLOCK_SIZE);

            /* encrypt the inputs in place to get the keystream */
            sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
                    keystream, nblks);

            crypto_xor_cpy(dst, src, keystream,
                    nblks * SM4_BLOCK_SIZE);
            dst += nblks * SM4_BLOCK_SIZE;
            src += nblks * SM4_BLOCK_SIZE;
            nbytes -= nblks * SM4_BLOCK_SIZE;
        }

        kernel_neon_end();

        /* tail: partial final block of the whole request */
        if (walk.nbytes == walk.total && nbytes > 0) {
            u8 keystream[SM4_BLOCK_SIZE];

            sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
            crypto_xor_cpy(dst, src, keystream, nbytes);
            nbytes = 0;
        }

        err = skcipher_walk_done(&walk, nbytes);
    }

    return err;
}
0291 
0292 static int sm4_ctr_crypt(struct skcipher_request *req)
0293 {
0294     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0295     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0296     struct skcipher_walk walk;
0297     unsigned int nbytes;
0298     int err;
0299 
0300     err = skcipher_walk_virt(&walk, req, false);
0301 
0302     while ((nbytes = walk.nbytes) > 0) {
0303         const u8 *src = walk.src.virt.addr;
0304         u8 *dst = walk.dst.virt.addr;
0305         unsigned int nblks;
0306 
0307         kernel_neon_begin();
0308 
0309         nblks = BYTES2BLK8(nbytes);
0310         if (nblks) {
0311             sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
0312                     walk.iv, nblks);
0313             dst += nblks * SM4_BLOCK_SIZE;
0314             src += nblks * SM4_BLOCK_SIZE;
0315             nbytes -= nblks * SM4_BLOCK_SIZE;
0316         }
0317 
0318         nblks = BYTES2BLKS(nbytes);
0319         if (nblks) {
0320             u8 keystream[SM4_BLOCK_SIZE * 8];
0321             int i;
0322 
0323             for (i = 0; i < nblks; i++) {
0324                 memcpy(&keystream[i * SM4_BLOCK_SIZE],
0325                     walk.iv, SM4_BLOCK_SIZE);
0326                 crypto_inc(walk.iv, SM4_BLOCK_SIZE);
0327             }
0328             sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
0329                     keystream, nblks);
0330 
0331             crypto_xor_cpy(dst, src, keystream,
0332                     nblks * SM4_BLOCK_SIZE);
0333             dst += nblks * SM4_BLOCK_SIZE;
0334             src += nblks * SM4_BLOCK_SIZE;
0335             nbytes -= nblks * SM4_BLOCK_SIZE;
0336         }
0337 
0338         kernel_neon_end();
0339 
0340         /* tail */
0341         if (walk.nbytes == walk.total && nbytes > 0) {
0342             u8 keystream[SM4_BLOCK_SIZE];
0343 
0344             sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
0345             crypto_inc(walk.iv, SM4_BLOCK_SIZE);
0346             crypto_xor_cpy(dst, src, keystream, nbytes);
0347             nbytes = 0;
0348         }
0349 
0350         err = skcipher_walk_done(&walk, nbytes);
0351     }
0352 
0353     return err;
0354 }
0355 
/*
 * Registered skcipher algorithms.  Priority 200 ranks these NEON
 * implementations above the generic C driver.  ECB/CBC advertise a
 * 16-byte block size; CFB/CTR are stream-like (cra_blocksize 1) with
 * chunksize set to the underlying block size.
 */
static struct skcipher_alg sm4_algs[] = {
    {
        /* ecb(sm4): no IV */
        .base = {
            .cra_name       = "ecb(sm4)",
            .cra_driver_name    = "ecb-sm4-neon",
            .cra_priority       = 200,
            .cra_blocksize      = SM4_BLOCK_SIZE,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .setkey     = sm4_setkey,
        .encrypt    = sm4_ecb_encrypt,
        .decrypt    = sm4_ecb_decrypt,
    }, {
        /* cbc(sm4): 16-byte IV, NEON-accelerated decrypt only */
        .base = {
            .cra_name       = "cbc(sm4)",
            .cra_driver_name    = "cbc-sm4-neon",
            .cra_priority       = 200,
            .cra_blocksize      = SM4_BLOCK_SIZE,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .setkey     = sm4_setkey,
        .encrypt    = sm4_cbc_encrypt,
        .decrypt    = sm4_cbc_decrypt,
    }, {
        /* cfb(sm4): stream mode, NEON-accelerated decrypt only */
        .base = {
            .cra_name       = "cfb(sm4)",
            .cra_driver_name    = "cfb-sm4-neon",
            .cra_priority       = 200,
            .cra_blocksize      = 1,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .chunksize  = SM4_BLOCK_SIZE,
        .setkey     = sm4_setkey,
        .encrypt    = sm4_cfb_encrypt,
        .decrypt    = sm4_cfb_decrypt,
    }, {
        /* ctr(sm4): stream mode, same routine both directions */
        .base = {
            .cra_name       = "ctr(sm4)",
            .cra_driver_name    = "ctr-sm4-neon",
            .cra_priority       = 200,
            .cra_blocksize      = 1,
            .cra_ctxsize        = sizeof(struct sm4_ctx),
            .cra_module     = THIS_MODULE,
        },
        .min_keysize    = SM4_KEY_SIZE,
        .max_keysize    = SM4_KEY_SIZE,
        .ivsize     = SM4_BLOCK_SIZE,
        .chunksize  = SM4_BLOCK_SIZE,
        .setkey     = sm4_setkey,
        .encrypt    = sm4_ctr_crypt,
        .decrypt    = sm4_ctr_crypt,
    }
};
0420 
/* Module entry: register all four SM4 skcipher implementations. */
static int __init sm4_init(void)
{
    return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}
0425 
/* Module exit: unregister everything registered in sm4_init(). */
static void __exit sm4_exit(void)
{
    crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}
0430 
0431 module_init(sm4_init);
0432 module_exit(sm4_exit);
0433 
0434 MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 NEON");
0435 MODULE_ALIAS_CRYPTO("sm4-neon");
0436 MODULE_ALIAS_CRYPTO("sm4");
0437 MODULE_ALIAS_CRYPTO("ecb(sm4)");
0438 MODULE_ALIAS_CRYPTO("cbc(sm4)");
0439 MODULE_ALIAS_CRYPTO("cfb(sm4)");
0440 MODULE_ALIAS_CRYPTO("ctr(sm4)");
0441 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
0442 MODULE_LICENSE("GPL v2");