Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * Cryptographic API.
0004  *
0005  * Copyright (c) 2017-present, Facebook, Inc.
0006  */
0007 #include <linux/crypto.h>
0008 #include <linux/init.h>
0009 #include <linux/interrupt.h>
0010 #include <linux/mm.h>
0011 #include <linux/module.h>
0012 #include <linux/net.h>
0013 #include <linux/vmalloc.h>
0014 #include <linux/zstd.h>
0015 #include <crypto/internal/scompress.h>
0016 
0017 
/* Compression level used for every context and every compress call in this file. */
enum { ZSTD_DEF_LEVEL = 3 };
0019 
0020 struct zstd_ctx {
0021     zstd_cctx *cctx;
0022     zstd_dctx *dctx;
0023     void *cwksp;
0024     void *dwksp;
0025 };
0026 
0027 static zstd_parameters zstd_params(void)
0028 {
0029     return zstd_get_params(ZSTD_DEF_LEVEL, 0);
0030 }
0031 
0032 static int zstd_comp_init(struct zstd_ctx *ctx)
0033 {
0034     int ret = 0;
0035     const zstd_parameters params = zstd_params();
0036     const size_t wksp_size = zstd_cctx_workspace_bound(&params.cParams);
0037 
0038     ctx->cwksp = vzalloc(wksp_size);
0039     if (!ctx->cwksp) {
0040         ret = -ENOMEM;
0041         goto out;
0042     }
0043 
0044     ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size);
0045     if (!ctx->cctx) {
0046         ret = -EINVAL;
0047         goto out_free;
0048     }
0049 out:
0050     return ret;
0051 out_free:
0052     vfree(ctx->cwksp);
0053     goto out;
0054 }
0055 
0056 static int zstd_decomp_init(struct zstd_ctx *ctx)
0057 {
0058     int ret = 0;
0059     const size_t wksp_size = zstd_dctx_workspace_bound();
0060 
0061     ctx->dwksp = vzalloc(wksp_size);
0062     if (!ctx->dwksp) {
0063         ret = -ENOMEM;
0064         goto out;
0065     }
0066 
0067     ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size);
0068     if (!ctx->dctx) {
0069         ret = -EINVAL;
0070         goto out_free;
0071     }
0072 out:
0073     return ret;
0074 out_free:
0075     vfree(ctx->dwksp);
0076     goto out;
0077 }
0078 
0079 static void zstd_comp_exit(struct zstd_ctx *ctx)
0080 {
0081     vfree(ctx->cwksp);
0082     ctx->cwksp = NULL;
0083     ctx->cctx = NULL;
0084 }
0085 
0086 static void zstd_decomp_exit(struct zstd_ctx *ctx)
0087 {
0088     vfree(ctx->dwksp);
0089     ctx->dwksp = NULL;
0090     ctx->dctx = NULL;
0091 }
0092 
/*
 * Initialise both halves of a struct zstd_ctx.
 *
 * Return: 0 on success, otherwise the first error from zstd_comp_init()
 * or zstd_decomp_init(). If decompression setup fails, the already
 * initialised compression half is torn down again.
 */
static int __zstd_init(void *ctx)
{
    int err = zstd_comp_init(ctx);

    if (err)
        return err;

    err = zstd_decomp_init(ctx);
    if (err) {
        /* roll back the compression half */
        zstd_comp_exit(ctx);
    }

    return err;
}
0105 
0106 static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
0107 {
0108     int ret;
0109     struct zstd_ctx *ctx;
0110 
0111     ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
0112     if (!ctx)
0113         return ERR_PTR(-ENOMEM);
0114 
0115     ret = __zstd_init(ctx);
0116     if (ret) {
0117         kfree(ctx);
0118         return ERR_PTR(ret);
0119     }
0120 
0121     return ctx;
0122 }
0123 
/* crypto_alg .cra_init hook: initialise the zstd_ctx embedded in @tfm. */
static int zstd_init(struct crypto_tfm *tfm)
{
    return __zstd_init(crypto_tfm_ctx(tfm));
}
0130 
/* Tear down both halves of a struct zstd_ctx (compression first). */
static void __zstd_exit(void *ctx)
{
    struct zstd_ctx *zctx = ctx;

    zstd_comp_exit(zctx);
    zstd_decomp_exit(zctx);
}
0136 
/*
 * scomp .free_ctx hook: release the workspaces, then the context itself
 * via kfree_sensitive(). @tfm is not used.
 */
static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
{
    struct zstd_ctx *zctx = ctx;

    __zstd_exit(zctx);
    kfree_sensitive(zctx);
}
0142 
/* crypto_alg .cra_exit hook: tear down the zstd_ctx embedded in @tfm. */
static void zstd_exit(struct crypto_tfm *tfm)
{
    __zstd_exit(crypto_tfm_ctx(tfm));
}
0149 
0150 static int __zstd_compress(const u8 *src, unsigned int slen,
0151                u8 *dst, unsigned int *dlen, void *ctx)
0152 {
0153     size_t out_len;
0154     struct zstd_ctx *zctx = ctx;
0155     const zstd_parameters params = zstd_params();
0156 
0157     out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, &params);
0158     if (zstd_is_error(out_len))
0159         return -EINVAL;
0160     *dlen = out_len;
0161     return 0;
0162 }
0163 
0164 static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
0165              unsigned int slen, u8 *dst, unsigned int *dlen)
0166 {
0167     struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
0168 
0169     return __zstd_compress(src, slen, dst, dlen, ctx);
0170 }
0171 
0172 static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
0173               unsigned int slen, u8 *dst, unsigned int *dlen,
0174               void *ctx)
0175 {
0176     return __zstd_compress(src, slen, dst, dlen, ctx);
0177 }
0178 
0179 static int __zstd_decompress(const u8 *src, unsigned int slen,
0180                  u8 *dst, unsigned int *dlen, void *ctx)
0181 {
0182     size_t out_len;
0183     struct zstd_ctx *zctx = ctx;
0184 
0185     out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen);
0186     if (zstd_is_error(out_len))
0187         return -EINVAL;
0188     *dlen = out_len;
0189     return 0;
0190 }
0191 
0192 static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
0193                unsigned int slen, u8 *dst, unsigned int *dlen)
0194 {
0195     struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
0196 
0197     return __zstd_decompress(src, slen, dst, dlen, ctx);
0198 }
0199 
0200 static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
0201                 unsigned int slen, u8 *dst, unsigned int *dlen,
0202                 void *ctx)
0203 {
0204     return __zstd_decompress(src, slen, dst, dlen, ctx);
0205 }
0206 
0207 static struct crypto_alg alg = {
0208     .cra_name       = "zstd",
0209     .cra_driver_name    = "zstd-generic",
0210     .cra_flags      = CRYPTO_ALG_TYPE_COMPRESS,
0211     .cra_ctxsize        = sizeof(struct zstd_ctx),
0212     .cra_module     = THIS_MODULE,
0213     .cra_init       = zstd_init,
0214     .cra_exit       = zstd_exit,
0215     .cra_u          = { .compress = {
0216     .coa_compress       = zstd_compress,
0217     .coa_decompress     = zstd_decompress } }
0218 };
0219 
0220 static struct scomp_alg scomp = {
0221     .alloc_ctx      = zstd_alloc_ctx,
0222     .free_ctx       = zstd_free_ctx,
0223     .compress       = zstd_scompress,
0224     .decompress     = zstd_sdecompress,
0225     .base           = {
0226         .cra_name   = "zstd",
0227         .cra_driver_name = "zstd-scomp",
0228         .cra_module  = THIS_MODULE,
0229     }
0230 };
0231 
0232 static int __init zstd_mod_init(void)
0233 {
0234     int ret;
0235 
0236     ret = crypto_register_alg(&alg);
0237     if (ret)
0238         return ret;
0239 
0240     ret = crypto_register_scomp(&scomp);
0241     if (ret)
0242         crypto_unregister_alg(&alg);
0243 
0244     return ret;
0245 }
0246 
0247 static void __exit zstd_mod_fini(void)
0248 {
0249     crypto_unregister_alg(&alg);
0250     crypto_unregister_scomp(&scomp);
0251 }
0252 
0253 subsys_initcall(zstd_mod_init);
0254 module_exit(zstd_mod_fini);
0255 
0256 MODULE_LICENSE("GPL");
0257 MODULE_DESCRIPTION("Zstd Compression Algorithm");
0258 MODULE_ALIAS_CRYPTO("zstd");