0001
0002
0003
0004
0005
0006
0007
0008
#include <crypto/algapi.h>
#include <crypto/internal/chacha.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/sizes.h>
#include <linux/string.h>
#include <asm/simd.h>
0017
/*
 * Assembly ChaCha routines.
 *
 * The *_xor_* routines take the cipher state, destination and source
 * buffers, the number of bytes still available (@len) and the round count.
 * The "Nblock" in each name is the largest number of 64-byte blocks one call
 * handles; chacha_dosimd() pairs each routine with chacha_advance(len, N).
 * The callers in this file advance state[12] themselves after every call,
 * so the routines presumably leave @state unmodified -- confirm against the
 * .S sources.
 */

/* SSSE3: baseline level, enabled whenever chacha_use_simd is set. */
asmlinkage void chacha_block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
					unsigned int len, int nrounds);
/* HChaCha; xchacha_simd() uses its output words as the XChaCha subkey. */
asmlinkage void hchacha_block_ssse3(const u32 *state, u32 *out, int nrounds);

/* AVX2: used only when chacha_use_avx2 is enabled. */
asmlinkage void chacha_2block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);

/* AVX-512VL: used only with CONFIG_AS_AVX512 and chacha_use_avx512vl. */
asmlinkage void chacha_2block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
0037
/*
 * CPU feature selections.  They are set only by chacha_simd_mod_init() and
 * are __ro_after_init afterwards:
 *  - chacha_use_simd:     SSSE3 present;
 *  - chacha_use_avx2:     additionally AVX, AVX2 and
 *                         cpu_has_xfeatures(SSE | YMM);
 *  - chacha_use_avx512vl: only on top of AVX2, additionally requires
 *                         CONFIG_AS_AVX512, AVX512VL and AVX512BW.
 */
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);
0041
0042 static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
0043 {
0044 len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
0045 return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
0046 }
0047
/*
 * chacha_dosimd - XOR @bytes of ChaCha keystream into @dst using SIMD code
 * @state:   cipher state; state[12] is advanced by the number of 64-byte
 *           blocks consumed (a partial final block counts as a whole one,
 *           see chacha_advance())
 * @dst:     output buffer
 * @src:     input buffer
 * @bytes:   number of bytes to process
 * @nrounds: round count, passed straight to the assembly
 *
 * Uses the widest level enabled by chacha_simd_mod_init(), in order
 * AVX-512VL, AVX2, SSSE3.  Each level loops over whole multi-block chunks
 * and then finishes the tail with one call to the smallest routine that
 * covers it.  Tails too short for the AVX2 routines fall through to the
 * SSSE3 single-block code at the end.
 *
 * All callers in this file wrap the call in kernel_fpu_begin() and
 * kernel_fpu_end().
 */
static void chacha_dosimd(u32 *state, u8 *dst, const u8 *src,
			  unsigned int bytes, int nrounds)
{
	if (IS_ENABLED(CONFIG_AS_AVX512) &&
	    static_branch_likely(&chacha_use_avx512vl)) {
		/* Whole 8-block (512-byte) chunks. */
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state[12] += 8;
		}
		/* Tail of 257..511 bytes: 5..8 blocks. */
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 8);
			return;
		}
		/* Tail of 129..256 bytes: 3..4 blocks. */
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 4);
			return;
		}
		/* Tail of 1..128 bytes: 1..2 blocks. */
		if (bytes) {
			chacha_2block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state[12] += chacha_advance(bytes, 2);
			return;
		}
		/* Only bytes == 0 gets here; everything below is a no-op. */
	}

	if (static_branch_likely(&chacha_use_avx2)) {
		/* Whole 8-block (512-byte) chunks. */
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state[12] += 8;
		}
		/* Tail of 5..8 blocks. */
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 8);
			return;
		}
		/* Tail of 3..4 blocks. */
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 4);
			return;
		}
		/* Tail of exactly 2 blocks (65..128 bytes). */
		if (bytes > CHACHA_BLOCK_SIZE) {
			chacha_2block_xor_avx2(state, dst, src, bytes, nrounds);
			state[12] += chacha_advance(bytes, 2);
			return;
		}
		/* 0..64 bytes left: handled by the SSSE3 code below. */
	}

	/* Whole 4-block (256-byte) chunks. */
	while (bytes >= CHACHA_BLOCK_SIZE * 4) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		bytes -= CHACHA_BLOCK_SIZE * 4;
		src += CHACHA_BLOCK_SIZE * 4;
		dst += CHACHA_BLOCK_SIZE * 4;
		state[12] += 4;
	}
	/* Tail of 65..255 bytes: 2..4 blocks. */
	if (bytes > CHACHA_BLOCK_SIZE) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		state[12] += chacha_advance(bytes, 4);
		return;
	}
	/* Tail of 1..64 bytes: a single (possibly partial) block. */
	if (bytes) {
		chacha_block_xor_ssse3(state, dst, src, bytes, nrounds);
		state[12]++;
	}
}
0123
0124 void hchacha_block_arch(const u32 *state, u32 *stream, int nrounds)
0125 {
0126 if (!static_branch_likely(&chacha_use_simd) || !crypto_simd_usable()) {
0127 hchacha_block_generic(state, stream, nrounds);
0128 } else {
0129 kernel_fpu_begin();
0130 hchacha_block_ssse3(state, stream, nrounds);
0131 kernel_fpu_end();
0132 }
0133 }
0134 EXPORT_SYMBOL(hchacha_block_arch);
0135
0136 void chacha_init_arch(u32 *state, const u32 *key, const u8 *iv)
0137 {
0138 chacha_init_generic(state, key, iv);
0139 }
0140 EXPORT_SYMBOL(chacha_init_arch);
0141
0142 void chacha_crypt_arch(u32 *state, u8 *dst, const u8 *src, unsigned int bytes,
0143 int nrounds)
0144 {
0145 if (!static_branch_likely(&chacha_use_simd) || !crypto_simd_usable() ||
0146 bytes <= CHACHA_BLOCK_SIZE)
0147 return chacha_crypt_generic(state, dst, src, bytes, nrounds);
0148
0149 do {
0150 unsigned int todo = min_t(unsigned int, bytes, SZ_4K);
0151
0152 kernel_fpu_begin();
0153 chacha_dosimd(state, dst, src, todo, nrounds);
0154 kernel_fpu_end();
0155
0156 bytes -= todo;
0157 src += todo;
0158 dst += todo;
0159 } while (bytes);
0160 }
0161 EXPORT_SYMBOL(chacha_crypt_arch);
0162
0163 static int chacha_simd_stream_xor(struct skcipher_request *req,
0164 const struct chacha_ctx *ctx, const u8 *iv)
0165 {
0166 u32 state[CHACHA_STATE_WORDS] __aligned(8);
0167 struct skcipher_walk walk;
0168 int err;
0169
0170 err = skcipher_walk_virt(&walk, req, false);
0171
0172 chacha_init_generic(state, ctx->key, iv);
0173
0174 while (walk.nbytes > 0) {
0175 unsigned int nbytes = walk.nbytes;
0176
0177 if (nbytes < walk.total)
0178 nbytes = round_down(nbytes, walk.stride);
0179
0180 if (!static_branch_likely(&chacha_use_simd) ||
0181 !crypto_simd_usable()) {
0182 chacha_crypt_generic(state, walk.dst.virt.addr,
0183 walk.src.virt.addr, nbytes,
0184 ctx->nrounds);
0185 } else {
0186 kernel_fpu_begin();
0187 chacha_dosimd(state, walk.dst.virt.addr,
0188 walk.src.virt.addr, nbytes,
0189 ctx->nrounds);
0190 kernel_fpu_end();
0191 }
0192 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
0193 }
0194
0195 return err;
0196 }
0197
0198 static int chacha_simd(struct skcipher_request *req)
0199 {
0200 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0201 struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
0202
0203 return chacha_simd_stream_xor(req, ctx, req->iv);
0204 }
0205
0206 static int xchacha_simd(struct skcipher_request *req)
0207 {
0208 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
0209 struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
0210 u32 state[CHACHA_STATE_WORDS] __aligned(8);
0211 struct chacha_ctx subctx;
0212 u8 real_iv[16];
0213
0214 chacha_init_generic(state, ctx->key, req->iv);
0215
0216 if (req->cryptlen > CHACHA_BLOCK_SIZE && crypto_simd_usable()) {
0217 kernel_fpu_begin();
0218 hchacha_block_ssse3(state, subctx.key, ctx->nrounds);
0219 kernel_fpu_end();
0220 } else {
0221 hchacha_block_generic(state, subctx.key, ctx->nrounds);
0222 }
0223 subctx.nrounds = ctx->nrounds;
0224
0225 memcpy(&real_iv[0], req->iv + 24, 8);
0226 memcpy(&real_iv[8], req->iv + 16, 8);
0227 return chacha_simd_stream_xor(req, &subctx, real_iv);
0228 }
0229
0230 static struct skcipher_alg algs[] = {
0231 {
0232 .base.cra_name = "chacha20",
0233 .base.cra_driver_name = "chacha20-simd",
0234 .base.cra_priority = 300,
0235 .base.cra_blocksize = 1,
0236 .base.cra_ctxsize = sizeof(struct chacha_ctx),
0237 .base.cra_module = THIS_MODULE,
0238
0239 .min_keysize = CHACHA_KEY_SIZE,
0240 .max_keysize = CHACHA_KEY_SIZE,
0241 .ivsize = CHACHA_IV_SIZE,
0242 .chunksize = CHACHA_BLOCK_SIZE,
0243 .setkey = chacha20_setkey,
0244 .encrypt = chacha_simd,
0245 .decrypt = chacha_simd,
0246 }, {
0247 .base.cra_name = "xchacha20",
0248 .base.cra_driver_name = "xchacha20-simd",
0249 .base.cra_priority = 300,
0250 .base.cra_blocksize = 1,
0251 .base.cra_ctxsize = sizeof(struct chacha_ctx),
0252 .base.cra_module = THIS_MODULE,
0253
0254 .min_keysize = CHACHA_KEY_SIZE,
0255 .max_keysize = CHACHA_KEY_SIZE,
0256 .ivsize = XCHACHA_IV_SIZE,
0257 .chunksize = CHACHA_BLOCK_SIZE,
0258 .setkey = chacha20_setkey,
0259 .encrypt = xchacha_simd,
0260 .decrypt = xchacha_simd,
0261 }, {
0262 .base.cra_name = "xchacha12",
0263 .base.cra_driver_name = "xchacha12-simd",
0264 .base.cra_priority = 300,
0265 .base.cra_blocksize = 1,
0266 .base.cra_ctxsize = sizeof(struct chacha_ctx),
0267 .base.cra_module = THIS_MODULE,
0268
0269 .min_keysize = CHACHA_KEY_SIZE,
0270 .max_keysize = CHACHA_KEY_SIZE,
0271 .ivsize = XCHACHA_IV_SIZE,
0272 .chunksize = CHACHA_BLOCK_SIZE,
0273 .setkey = chacha12_setkey,
0274 .encrypt = xchacha_simd,
0275 .decrypt = xchacha_simd,
0276 },
0277 };
0278
0279 static int __init chacha_simd_mod_init(void)
0280 {
0281 if (!boot_cpu_has(X86_FEATURE_SSSE3))
0282 return 0;
0283
0284 static_branch_enable(&chacha_use_simd);
0285
0286 if (boot_cpu_has(X86_FEATURE_AVX) &&
0287 boot_cpu_has(X86_FEATURE_AVX2) &&
0288 cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
0289 static_branch_enable(&chacha_use_avx2);
0290
0291 if (IS_ENABLED(CONFIG_AS_AVX512) &&
0292 boot_cpu_has(X86_FEATURE_AVX512VL) &&
0293 boot_cpu_has(X86_FEATURE_AVX512BW))
0294 static_branch_enable(&chacha_use_avx512vl);
0295 }
0296 return IS_REACHABLE(CONFIG_CRYPTO_SKCIPHER) ?
0297 crypto_register_skciphers(algs, ARRAY_SIZE(algs)) : 0;
0298 }
0299
0300 static void __exit chacha_simd_mod_fini(void)
0301 {
0302 if (IS_REACHABLE(CONFIG_CRYPTO_SKCIPHER) && boot_cpu_has(X86_FEATURE_SSSE3))
0303 crypto_unregister_skciphers(algs, ARRAY_SIZE(algs));
0304 }
0305
module_init(chacha_simd_mod_init);
module_exit(chacha_simd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Martin Willi <martin@strongswan.org>");
MODULE_DESCRIPTION("ChaCha and XChaCha stream ciphers (x64 SIMD accelerated)");
/* One alias per cra_name and cra_driver_name registered in algs[]. */
MODULE_ALIAS_CRYPTO("chacha20");
MODULE_ALIAS_CRYPTO("chacha20-simd");
MODULE_ALIAS_CRYPTO("xchacha20");
MODULE_ALIAS_CRYPTO("xchacha20-simd");
MODULE_ALIAS_CRYPTO("xchacha12");
MODULE_ALIAS_CRYPTO("xchacha12-simd");