Back to home page

OSCL-LXR

 
 

    


0001 /* SPDX-License-Identifier: GPL-2.0-or-later */
0002 /*
0003  * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
0004  * as specified in
0005  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
0006  *
0007  * Copyright (C) 2022, Alibaba Group.
0008  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
0009  */
0010 
0011 #include <linux/module.h>
0012 #include <linux/crypto.h>
0013 #include <linux/kernel.h>
0014 #include <linux/cpufeature.h>
0015 #include <asm/neon.h>
0016 #include <asm/simd.h>
0017 #include <crypto/internal/simd.h>
0018 #include <crypto/internal/skcipher.h>
0019 #include <crypto/sm4.h>
0020 
0021 #define BYTES2BLKS(nbytes)  ((nbytes) >> 4)
0022 
0023 asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
0024                   const u32 *fk, const u32 *ck);
0025 asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
0026 asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
0027                  unsigned int nblks);
0028 asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
0029                    u8 *iv, unsigned int nblks);
0030 asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
0031                    u8 *iv, unsigned int nblks);
0032 asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
0033                    u8 *iv, unsigned int nblks);
0034 asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
0035                    u8 *iv, unsigned int nblks);
0036 asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
0037                    u8 *iv, unsigned int nblks);
0038 
0039 static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
0040               unsigned int key_len)
0041 {
0042     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0043 
0044     if (key_len != SM4_KEY_SIZE)
0045         return -EINVAL;
0046 
0047     sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
0048               crypto_sm4_fk, crypto_sm4_ck);
0049     return 0;
0050 }
0051 
0052 static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
0053 {
0054     struct skcipher_walk walk;
0055     unsigned int nbytes;
0056     int err;
0057 
0058     err = skcipher_walk_virt(&walk, req, false);
0059 
0060     while ((nbytes = walk.nbytes) > 0) {
0061         const u8 *src = walk.src.virt.addr;
0062         u8 *dst = walk.dst.virt.addr;
0063         unsigned int nblks;
0064 
0065         kernel_neon_begin();
0066 
0067         nblks = BYTES2BLKS(nbytes);
0068         if (nblks) {
0069             sm4_ce_crypt(rkey, dst, src, nblks);
0070             nbytes -= nblks * SM4_BLOCK_SIZE;
0071         }
0072 
0073         kernel_neon_end();
0074 
0075         err = skcipher_walk_done(&walk, nbytes);
0076     }
0077 
0078     return err;
0079 }
0080 
0081 static int sm4_ecb_encrypt(struct skcipher_request *req)
0082 {
0083     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0084     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0085 
0086     return sm4_ecb_do_crypt(req, ctx->rkey_enc);
0087 }
0088 
0089 static int sm4_ecb_decrypt(struct skcipher_request *req)
0090 {
0091     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0092     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0093 
0094     return sm4_ecb_do_crypt(req, ctx->rkey_dec);
0095 }
0096 
0097 static int sm4_cbc_encrypt(struct skcipher_request *req)
0098 {
0099     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0100     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0101     struct skcipher_walk walk;
0102     unsigned int nbytes;
0103     int err;
0104 
0105     err = skcipher_walk_virt(&walk, req, false);
0106 
0107     while ((nbytes = walk.nbytes) > 0) {
0108         const u8 *src = walk.src.virt.addr;
0109         u8 *dst = walk.dst.virt.addr;
0110         unsigned int nblks;
0111 
0112         kernel_neon_begin();
0113 
0114         nblks = BYTES2BLKS(nbytes);
0115         if (nblks) {
0116             sm4_ce_cbc_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
0117             nbytes -= nblks * SM4_BLOCK_SIZE;
0118         }
0119 
0120         kernel_neon_end();
0121 
0122         err = skcipher_walk_done(&walk, nbytes);
0123     }
0124 
0125     return err;
0126 }
0127 
0128 static int sm4_cbc_decrypt(struct skcipher_request *req)
0129 {
0130     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0131     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0132     struct skcipher_walk walk;
0133     unsigned int nbytes;
0134     int err;
0135 
0136     err = skcipher_walk_virt(&walk, req, false);
0137 
0138     while ((nbytes = walk.nbytes) > 0) {
0139         const u8 *src = walk.src.virt.addr;
0140         u8 *dst = walk.dst.virt.addr;
0141         unsigned int nblks;
0142 
0143         kernel_neon_begin();
0144 
0145         nblks = BYTES2BLKS(nbytes);
0146         if (nblks) {
0147             sm4_ce_cbc_dec(ctx->rkey_dec, dst, src, walk.iv, nblks);
0148             nbytes -= nblks * SM4_BLOCK_SIZE;
0149         }
0150 
0151         kernel_neon_end();
0152 
0153         err = skcipher_walk_done(&walk, nbytes);
0154     }
0155 
0156     return err;
0157 }
0158 
0159 static int sm4_cfb_encrypt(struct skcipher_request *req)
0160 {
0161     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0162     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0163     struct skcipher_walk walk;
0164     unsigned int nbytes;
0165     int err;
0166 
0167     err = skcipher_walk_virt(&walk, req, false);
0168 
0169     while ((nbytes = walk.nbytes) > 0) {
0170         const u8 *src = walk.src.virt.addr;
0171         u8 *dst = walk.dst.virt.addr;
0172         unsigned int nblks;
0173 
0174         kernel_neon_begin();
0175 
0176         nblks = BYTES2BLKS(nbytes);
0177         if (nblks) {
0178             sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
0179             dst += nblks * SM4_BLOCK_SIZE;
0180             src += nblks * SM4_BLOCK_SIZE;
0181             nbytes -= nblks * SM4_BLOCK_SIZE;
0182         }
0183 
0184         /* tail */
0185         if (walk.nbytes == walk.total && nbytes > 0) {
0186             u8 keystream[SM4_BLOCK_SIZE];
0187 
0188             sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
0189             crypto_xor_cpy(dst, src, keystream, nbytes);
0190             nbytes = 0;
0191         }
0192 
0193         kernel_neon_end();
0194 
0195         err = skcipher_walk_done(&walk, nbytes);
0196     }
0197 
0198     return err;
0199 }
0200 
0201 static int sm4_cfb_decrypt(struct skcipher_request *req)
0202 {
0203     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0204     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0205     struct skcipher_walk walk;
0206     unsigned int nbytes;
0207     int err;
0208 
0209     err = skcipher_walk_virt(&walk, req, false);
0210 
0211     while ((nbytes = walk.nbytes) > 0) {
0212         const u8 *src = walk.src.virt.addr;
0213         u8 *dst = walk.dst.virt.addr;
0214         unsigned int nblks;
0215 
0216         kernel_neon_begin();
0217 
0218         nblks = BYTES2BLKS(nbytes);
0219         if (nblks) {
0220             sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
0221             dst += nblks * SM4_BLOCK_SIZE;
0222             src += nblks * SM4_BLOCK_SIZE;
0223             nbytes -= nblks * SM4_BLOCK_SIZE;
0224         }
0225 
0226         /* tail */
0227         if (walk.nbytes == walk.total && nbytes > 0) {
0228             u8 keystream[SM4_BLOCK_SIZE];
0229 
0230             sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
0231             crypto_xor_cpy(dst, src, keystream, nbytes);
0232             nbytes = 0;
0233         }
0234 
0235         kernel_neon_end();
0236 
0237         err = skcipher_walk_done(&walk, nbytes);
0238     }
0239 
0240     return err;
0241 }
0242 
0243 static int sm4_ctr_crypt(struct skcipher_request *req)
0244 {
0245     struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0246     struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
0247     struct skcipher_walk walk;
0248     unsigned int nbytes;
0249     int err;
0250 
0251     err = skcipher_walk_virt(&walk, req, false);
0252 
0253     while ((nbytes = walk.nbytes) > 0) {
0254         const u8 *src = walk.src.virt.addr;
0255         u8 *dst = walk.dst.virt.addr;
0256         unsigned int nblks;
0257 
0258         kernel_neon_begin();
0259 
0260         nblks = BYTES2BLKS(nbytes);
0261         if (nblks) {
0262             sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
0263             dst += nblks * SM4_BLOCK_SIZE;
0264             src += nblks * SM4_BLOCK_SIZE;
0265             nbytes -= nblks * SM4_BLOCK_SIZE;
0266         }
0267 
0268         /* tail */
0269         if (walk.nbytes == walk.total && nbytes > 0) {
0270             u8 keystream[SM4_BLOCK_SIZE];
0271 
0272             sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
0273             crypto_inc(walk.iv, SM4_BLOCK_SIZE);
0274             crypto_xor_cpy(dst, src, keystream, nbytes);
0275             nbytes = 0;
0276         }
0277 
0278         kernel_neon_end();
0279 
0280         err = skcipher_walk_done(&walk, nbytes);
0281     }
0282 
0283     return err;
0284 }
0285 
0286 static struct skcipher_alg sm4_algs[] = {
0287     {
0288         .base = {
0289             .cra_name       = "ecb(sm4)",
0290             .cra_driver_name    = "ecb-sm4-ce",
0291             .cra_priority       = 400,
0292             .cra_blocksize      = SM4_BLOCK_SIZE,
0293             .cra_ctxsize        = sizeof(struct sm4_ctx),
0294             .cra_module     = THIS_MODULE,
0295         },
0296         .min_keysize    = SM4_KEY_SIZE,
0297         .max_keysize    = SM4_KEY_SIZE,
0298         .setkey     = sm4_setkey,
0299         .encrypt    = sm4_ecb_encrypt,
0300         .decrypt    = sm4_ecb_decrypt,
0301     }, {
0302         .base = {
0303             .cra_name       = "cbc(sm4)",
0304             .cra_driver_name    = "cbc-sm4-ce",
0305             .cra_priority       = 400,
0306             .cra_blocksize      = SM4_BLOCK_SIZE,
0307             .cra_ctxsize        = sizeof(struct sm4_ctx),
0308             .cra_module     = THIS_MODULE,
0309         },
0310         .min_keysize    = SM4_KEY_SIZE,
0311         .max_keysize    = SM4_KEY_SIZE,
0312         .ivsize     = SM4_BLOCK_SIZE,
0313         .setkey     = sm4_setkey,
0314         .encrypt    = sm4_cbc_encrypt,
0315         .decrypt    = sm4_cbc_decrypt,
0316     }, {
0317         .base = {
0318             .cra_name       = "cfb(sm4)",
0319             .cra_driver_name    = "cfb-sm4-ce",
0320             .cra_priority       = 400,
0321             .cra_blocksize      = 1,
0322             .cra_ctxsize        = sizeof(struct sm4_ctx),
0323             .cra_module     = THIS_MODULE,
0324         },
0325         .min_keysize    = SM4_KEY_SIZE,
0326         .max_keysize    = SM4_KEY_SIZE,
0327         .ivsize     = SM4_BLOCK_SIZE,
0328         .chunksize  = SM4_BLOCK_SIZE,
0329         .setkey     = sm4_setkey,
0330         .encrypt    = sm4_cfb_encrypt,
0331         .decrypt    = sm4_cfb_decrypt,
0332     }, {
0333         .base = {
0334             .cra_name       = "ctr(sm4)",
0335             .cra_driver_name    = "ctr-sm4-ce",
0336             .cra_priority       = 400,
0337             .cra_blocksize      = 1,
0338             .cra_ctxsize        = sizeof(struct sm4_ctx),
0339             .cra_module     = THIS_MODULE,
0340         },
0341         .min_keysize    = SM4_KEY_SIZE,
0342         .max_keysize    = SM4_KEY_SIZE,
0343         .ivsize     = SM4_BLOCK_SIZE,
0344         .chunksize  = SM4_BLOCK_SIZE,
0345         .setkey     = sm4_setkey,
0346         .encrypt    = sm4_ctr_crypt,
0347         .decrypt    = sm4_ctr_crypt,
0348     }
0349 };
0350 
0351 static int __init sm4_init(void)
0352 {
0353     return crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
0354 }
0355 
0356 static void __exit sm4_exit(void)
0357 {
0358     crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
0359 }
0360 
0361 module_cpu_feature_match(SM4, sm4_init);
0362 module_exit(sm4_exit);
0363 
0364 MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR using ARMv8 Crypto Extensions");
0365 MODULE_ALIAS_CRYPTO("sm4-ce");
0366 MODULE_ALIAS_CRYPTO("sm4");
0367 MODULE_ALIAS_CRYPTO("ecb(sm4)");
0368 MODULE_ALIAS_CRYPTO("cbc(sm4)");
0369 MODULE_ALIAS_CRYPTO("cfb(sm4)");
0370 MODULE_ALIAS_CRYPTO("ctr(sm4)");
0371 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
0372 MODULE_LICENSE("GPL v2");