0001
0002
0003
0004
0005
0006
0007 #include <linux/crypto.h>
0008 #include <linux/init.h>
0009 #include <linux/interrupt.h>
0010 #include <linux/mm.h>
0011 #include <linux/module.h>
0012 #include <linux/net.h>
0013 #include <linux/vmalloc.h>
0014 #include <linux/zstd.h>
0015 #include <crypto/internal/scompress.h>
0016
0017
/*
 * Fixed zstd compression level. zstd_params() feeds it to zstd_get_params(),
 * and the result is used both to size the compression workspace and to
 * compress.
 */
#define ZSTD_DEF_LEVEL 3
0019
/*
 * zstd state shared by the crypto_alg front end (embedded in the tfm via
 * .cra_ctxsize) and the scomp front end (kzalloc()'d by zstd_alloc_ctx()).
 * Each direction owns a vzalloc()'d workspace on which its zstd context is
 * initialized. The exit helpers only vfree() the workspace and then clear
 * the context pointer.
 */
struct zstd_ctx {
	zstd_cctx *cctx;	/* from zstd_init_cctx() on @cwksp */
	zstd_dctx *dctx;	/* from zstd_init_dctx() on @dwksp */
	void *cwksp;		/* vzalloc()'d compression workspace */
	void *dwksp;		/* vzalloc()'d decompression workspace */
};
0026
0027 static zstd_parameters zstd_params(void)
0028 {
0029 return zstd_get_params(ZSTD_DEF_LEVEL, 0);
0030 }
0031
0032 static int zstd_comp_init(struct zstd_ctx *ctx)
0033 {
0034 int ret = 0;
0035 const zstd_parameters params = zstd_params();
0036 const size_t wksp_size = zstd_cctx_workspace_bound(¶ms.cParams);
0037
0038 ctx->cwksp = vzalloc(wksp_size);
0039 if (!ctx->cwksp) {
0040 ret = -ENOMEM;
0041 goto out;
0042 }
0043
0044 ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
0045 if (!ctx->cctx) {
0046 ret = -EINVAL;
0047 goto out_free;
0048 }
0049 out:
0050 return ret;
0051 out_free:
0052 vfree(ctx->cwksp);
0053 goto out;
0054 }
0055
0056 static int zstd_decomp_init(struct zstd_ctx *ctx)
0057 {
0058 int ret = 0;
0059 const size_t wksp_size = zstd_dctx_workspace_bound();
0060
0061 ctx->dwksp = vzalloc(wksp_size);
0062 if (!ctx->dwksp) {
0063 ret = -ENOMEM;
0064 goto out;
0065 }
0066
0067 ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
0068 if (!ctx->dctx) {
0069 ret = -EINVAL;
0070 goto out_free;
0071 }
0072 out:
0073 return ret;
0074 out_free:
0075 vfree(ctx->dwksp);
0076 goto out;
0077 }
0078
0079 static void zstd_comp_exit(struct zstd_ctx *ctx)
0080 {
0081 vfree(ctx->cwksp);
0082 ctx->cwksp = NULL;
0083 ctx->cctx = NULL;
0084 }
0085
0086 static void zstd_decomp_exit(struct zstd_ctx *ctx)
0087 {
0088 vfree(ctx->dwksp);
0089 ctx->dwksp = NULL;
0090 ctx->dctx = NULL;
0091 }
0092
/*
 * Initialize both directions on @ctx (a struct zstd_ctx). If decompression
 * setup fails, the already-built compression side is released, so the
 * caller gets all-or-nothing.
 *
 * Return: 0 on success, or the error from zstd_comp_init() or
 * zstd_decomp_init().
 */
static int __zstd_init(void *ctx)
{
	int err = zstd_comp_init(ctx);

	if (!err) {
		err = zstd_decomp_init(ctx);
		if (err)
			zstd_comp_exit(ctx);
	}

	return err;
}
0105
0106 static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
0107 {
0108 int ret;
0109 struct zstd_ctx *ctx;
0110
0111 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
0112 if (!ctx)
0113 return ERR_PTR(-ENOMEM);
0114
0115 ret = __zstd_init(ctx);
0116 if (ret) {
0117 kfree(ctx);
0118 return ERR_PTR(ret);
0119 }
0120
0121 return ctx;
0122 }
0123
/* crypto_alg .cra_init hook: set up the tfm-embedded context (.cra_ctxsize). */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
0130
/*
 * Release both directions of @ctx (a struct zstd_ctx). The two helpers touch
 * disjoint fields; they run in reverse order of __zstd_init().
 */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_decomp_exit(zctx);
	zstd_comp_exit(zctx);
}
0136
/*
 * scomp .free_ctx hook: release the workspaces, then free the context
 * itself with kfree_sensitive(). @tfm is unused.
 */
static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	__zstd_exit(zctx);
	kfree_sensitive(zctx);
}
0142
/* crypto_alg .cra_exit hook: release the tfm-embedded context's workspaces. */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
0149
0150 static int __zstd_compress(const u8 *src, unsigned int slen,
0151 u8 *dst, unsigned int *dlen, void *ctx)
0152 {
0153 size_t out_len;
0154 struct zstd_ctx *zctx = ctx;
0155 const zstd_parameters params = zstd_params();
0156
0157 out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, ¶ms);
0158 if (zstd_is_error(out_len))
0159 return -EINVAL;
0160 *dlen = out_len;
0161 return 0;
0162 }
0163
0164 static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
0165 unsigned int slen, u8 *dst, unsigned int *dlen)
0166 {
0167 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
0168
0169 return __zstd_compress(src, slen, dst, dlen, ctx);
0170 }
0171
0172 static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
0173 unsigned int slen, u8 *dst, unsigned int *dlen,
0174 void *ctx)
0175 {
0176 return __zstd_compress(src, slen, dst, dlen, ctx);
0177 }
0178
0179 static int __zstd_decompress(const u8 *src, unsigned int slen,
0180 u8 *dst, unsigned int *dlen, void *ctx)
0181 {
0182 size_t out_len;
0183 struct zstd_ctx *zctx = ctx;
0184
0185 out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
0186 if (zstd_is_error(out_len))
0187 return -EINVAL;
0188 *dlen = out_len;
0189 return 0;
0190 }
0191
0192 static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
0193 unsigned int slen, u8 *dst, unsigned int *dlen)
0194 {
0195 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
0196
0197 return __zstd_decompress(src, slen, dst, dlen, ctx);
0198 }
0199
0200 static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
0201 unsigned int slen, u8 *dst, unsigned int *dlen,
0202 void *ctx)
0203 {
0204 return __zstd_decompress(src, slen, dst, dlen, ctx);
0205 }
0206
0207 static struct crypto_alg alg = {
0208 .cra_name = "zstd",
0209 .cra_driver_name = "zstd-generic",
0210 .cra_flags = CRYPTO_ALG_TYPE_COMPRESS,
0211 .cra_ctxsize = sizeof(struct zstd_ctx),
0212 .cra_module = THIS_MODULE,
0213 .cra_init = zstd_init,
0214 .cra_exit = zstd_exit,
0215 .cra_u = { .compress = {
0216 .coa_compress = zstd_compress,
0217 .coa_decompress = zstd_decompress } }
0218 };
0219
0220 static struct scomp_alg scomp = {
0221 .alloc_ctx = zstd_alloc_ctx,
0222 .free_ctx = zstd_free_ctx,
0223 .compress = zstd_scompress,
0224 .decompress = zstd_sdecompress,
0225 .base = {
0226 .cra_name = "zstd",
0227 .cra_driver_name = "zstd-scomp",
0228 .cra_module = THIS_MODULE,
0229 }
0230 };
0231
0232 static int __init zstd_mod_init(void)
0233 {
0234 int ret;
0235
0236 ret = crypto_register_alg(&alg);
0237 if (ret)
0238 return ret;
0239
0240 ret = crypto_register_scomp(&scomp);
0241 if (ret)
0242 crypto_unregister_alg(&alg);
0243
0244 return ret;
0245 }
0246
0247 static void __exit zstd_mod_fini(void)
0248 {
0249 crypto_unregister_alg(&alg);
0250 crypto_unregister_scomp(&scomp);
0251 }
0252
/*
 * subsys_initcall() rather than module_init(): when built in, this
 * presumably makes "zstd" available before later-initialized built-in users
 * (NOTE(review): confirm which consumers depend on this ordering).
 */
subsys_initcall(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
/* Lets a crypto lookup for "zstd" autoload this module. */
MODULE_ALIAS_CRYPTO("zstd");