0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068 #ifdef STATIC
0069 # define UNZSTD_PREBOOT
0070 # include "xxhash.c"
0071 # include "zstd/decompress_sources.h"
0072 #endif
0073
0074 #include <linux/decompress/mm.h>
0075 #include <linux/kernel.h>
0076 #include <linux/zstd.h>
0077
0078
/*
 * Largest frame window size accepted by __unzstd() (compared against the
 * frame header's windowSize before the streaming workspace is allocated),
 * derived from zstd's ZSTD_WINDOWLOG_MAX.
 *
 * The shift is performed on an unsigned long long on purpose: a plain
 * "1 << ZSTD_WINDOWLOG_MAX" is an int shift, which is undefined behaviour
 * when ZSTD_WINDOWLOG_MAX is 31 (the 64-bit value in current upstream zstd;
 * NOTE(review): confirm against the <linux/zstd.h> in use). GCC evaluates it
 * to INT_MIN, which turns into an enormous value when compared with an
 * unsigned 64-bit window size (assumed to be unsigned long long as in
 * upstream zstd), so the "too large a window size" check could never fire.
 * For log values <= 30 the result is numerically identical to before.
 */
#define ZSTD_WINDOWSIZE_MAX	(1ULL << ZSTD_WINDOWLOG_MAX)

/*
 * Size (128 KiB) of the input buffer allocated when the caller passes no
 * input buffer, of the output buffer allocated when a flush() callback is
 * supplied, and of every request made to the fill() callback.
 */
#define ZSTD_IOBUF_SIZE		(1 << 17)
0087
0088 static int INIT handle_zstd_error(size_t ret, void (*error)(char *x))
0089 {
0090 const zstd_error_code err = zstd_get_error_code(ret);
0091
0092 if (!zstd_is_error(ret))
0093 return 0;
0094
0095
0096
0097
0098
0099 switch (err) {
0100 case ZSTD_error_memory_allocation:
0101 error("ZSTD decompressor ran out of memory");
0102 break;
0103 case ZSTD_error_prefix_unknown:
0104 error("Input is not in the ZSTD format (wrong magic bytes)");
0105 break;
0106 case ZSTD_error_dstSize_tooSmall:
0107 case ZSTD_error_corruption_detected:
0108 case ZSTD_error_checksum_wrong:
0109 error("ZSTD-compressed data is corrupt");
0110 break;
0111 default:
0112 error("ZSTD-compressed data is probably corrupt");
0113 break;
0114 }
0115 return -1;
0116 }
0117
0118
0119
0120
0121
0122
0123 static int INIT decompress_single(const u8 *in_buf, long in_len, u8 *out_buf,
0124 long out_len, long *in_pos,
0125 void (*error)(char *x))
0126 {
0127 const size_t wksp_size = zstd_dctx_workspace_bound();
0128 void *wksp = large_malloc(wksp_size);
0129 zstd_dctx *dctx = zstd_init_dctx(wksp, wksp_size);
0130 int err;
0131 size_t ret;
0132
0133 if (dctx == NULL) {
0134 error("Out of memory while allocating zstd_dctx");
0135 err = -1;
0136 goto out;
0137 }
0138
0139
0140
0141
0142 ret = zstd_find_frame_compressed_size(in_buf, in_len);
0143 err = handle_zstd_error(ret, error);
0144 if (err)
0145 goto out;
0146 in_len = (long)ret;
0147
0148 ret = zstd_decompress_dctx(dctx, out_buf, out_len, in_buf, in_len);
0149 err = handle_zstd_error(ret, error);
0150 if (err)
0151 goto out;
0152
0153 if (in_pos != NULL)
0154 *in_pos = in_len;
0155
0156 err = 0;
0157 out:
0158 if (wksp != NULL)
0159 large_free(wksp);
0160 return err;
0161 }
0162
/*
 * __unzstd() - decompress a zstd frame, either in one shot or streaming.
 * @in_buf:  compressed input; if NULL, a ZSTD_IOBUF_SIZE buffer is allocated
 *           here and filled through @fill.
 * @in_len:  number of valid bytes in @in_buf; replaced by @fill's return value
 *           whenever @fill is non-NULL.
 * @fill:    optional callback that refills the input buffer. It is asked for
 *           up to ZSTD_IOBUF_SIZE bytes; a negative return is reported as
 *           truncated input.
 * @flush:   optional callback that consumes decompressed output. When set,
 *           output is staged in an internally allocated ZSTD_IOBUF_SIZE
 *           buffer, and a return value different from the byte count passed
 *           in is treated as an error.
 * @out_buf: destination buffer when @flush is NULL.
 * @out_len: size of @out_buf; 0 means "as large as possible without the end
 *           address wrapping".
 * @in_pos:  optional; receives the number of input bytes consumed.
 * @error:   callback invoked with a message on every failure path.
 *
 * With neither @fill nor @flush the work is delegated to decompress_single().
 * Otherwise the frame header is parsed from the first input chunk, a
 * zstd_dstream sized for the frame's window is created, and
 * zstd_decompress_stream() is called until it returns 0.
 *
 * NOTE(review): when the caller supplies @in_buf together with @fill, @fill is
 * still asked for ZSTD_IOBUF_SIZE bytes into that buffer, so it presumably
 * must be at least that large - confirm callers honour this.
 *
 * Return: 0 on success, -1 on failure (after @error has been called).
 */
static int INIT __unzstd(unsigned char *in_buf, long in_len,
			 long (*fill)(void*, unsigned long),
			 long (*flush)(void*, unsigned long),
			 unsigned char *out_buf, long out_len,
			 long *in_pos,
			 void (*error)(char *x))
{
	zstd_in_buffer in;
	zstd_out_buffer out;
	zstd_frame_header header;
	void *in_allocated = NULL;
	void *out_allocated = NULL;
	void *wksp = NULL;
	size_t wksp_size;
	zstd_dstream *dstream;
	int err;
	size_t ret;

	/*
	 * No output size given: use the largest length for which
	 * out_buf + out_len does not wrap around the address space.
	 * NOTE(review): the uintptr_t result is stored in a long and can exceed
	 * LONG_MAX on 64-bit; this relies on it converting back to the same
	 * value when later passed on as a size_t.
	 */
	if (out_len == 0)
		out_len = UINTPTR_MAX - (uintptr_t)out_buf;

	if (fill == NULL && flush == NULL)
		/*
		 * Whole input and output are already in memory: use the
		 * single-call path instead of streaming.
		 */
		return decompress_single(in_buf, in_len, out_buf, out_len,
					 in_pos, error);

	/*
	 * Streaming mode from here on.
	 * Allocate an input buffer if the caller did not provide one.
	 */
	if (in_buf == NULL) {
		in_allocated = large_malloc(ZSTD_IOBUF_SIZE);
		if (in_allocated == NULL) {
			error("Out of memory while allocating input buffer");
			err = -1;
			goto out;
		}
		in_buf = in_allocated;
		in_len = 0;
	}

	/* Load the first chunk of input (needed to read the frame header). */
	if (fill != NULL)
		in_len = fill(in_buf, ZSTD_IOBUF_SIZE);
	if (in_len < 0) {
		error("ZSTD-compressed data is truncated");
		err = -1;
		goto out;
	}

	in.src = in_buf;
	in.pos = 0;
	in.size = in_len;

	/* With a flush() callback, decompress into a reusable bounce buffer. */
	if (flush != NULL) {
		out_allocated = large_malloc(ZSTD_IOBUF_SIZE);
		if (out_allocated == NULL) {
			error("Out of memory while allocating output buffer");
			err = -1;
			goto out;
		}
		out_buf = out_allocated;
		out_len = ZSTD_IOBUF_SIZE;
	}

	out.dst = out_buf;
	out.pos = 0;
	out.size = out_len;

	/*
	 * Read the frame header from the first chunk so the stream workspace
	 * can be sized for the frame's window. A non-zero, non-error return
	 * means the chunk did not contain the complete header, which is
	 * reported as an error rather than refilled.
	 */
	ret = zstd_get_frame_header(&header, in.src, in.size);
	err = handle_zstd_error(ret, error);
	if (err)
		goto out;
	if (ret != 0) {
		error("ZSTD-compressed data has an incomplete frame header");
		err = -1;
		goto out;
	}
	/* Refuse windows larger than ZSTD_WINDOWSIZE_MAX before allocating. */
	if (header.windowSize > ZSTD_WINDOWSIZE_MAX) {
		error("ZSTD-compressed data has too large a window size");
		err = -1;
		goto out;
	}

	/*
	 * Allocate the stream workspace for exactly this window size.
	 * The large_malloc() result is not checked separately; a NULL workspace
	 * is presumably rejected by zstd_init_dstream() (NOTE(review): confirm),
	 * which then lands in the dstream == NULL branch below.
	 */
	wksp_size = zstd_dstream_workspace_bound(header.windowSize);
	wksp = large_malloc(wksp_size);
	dstream = zstd_init_dstream(header.windowSize, wksp, wksp_size);
	if (dstream == NULL) {
		error("Out of memory while allocating ZSTD_DStream");
		err = -1;
		goto out;
	}

	/*
	 * Main streaming loop. *in_pos is accumulated one whole chunk at a time
	 * as each chunk is used up; the partially consumed last chunk is added
	 * after the loop.
	 */
	if (in_pos != NULL)
		*in_pos = 0;
	do {
		/*
		 * Current chunk fully consumed: count it, then refill. Without a
		 * fill() callback there is no more input, which is reported as
		 * truncation.
		 */
		if (in.pos == in.size) {
			if (in_pos != NULL)
				*in_pos += in.pos;
			in_len = fill ? fill(in_buf, ZSTD_IOBUF_SIZE) : -1;
			if (in_len < 0) {
				error("ZSTD-compressed data is truncated");
				err = -1;
				goto out;
			}
			in.pos = 0;
			in.size = in_len;
		}

		ret = zstd_decompress_stream(dstream, &out, &in);
		err = handle_zstd_error(ret, error);
		if (err)
			goto out;

		/* Hand any produced output to flush() and reuse the buffer. */
		if (flush != NULL && out.pos > 0) {
			if (out.pos != flush(out.dst, out.pos)) {
				error("Failed to flush()");
				err = -1;
				goto out;
			}
			out.pos = 0;
		}
	} while (ret != 0); /* zstd reports a completed frame by returning 0 */

	if (in_pos != NULL)
		*in_pos += in.pos;

	err = 0;
out:
	/* NULL checks kept: the pre-boot large_free() may not accept NULL. */
	if (in_allocated != NULL)
		large_free(in_allocated);
	if (out_allocated != NULL)
		large_free(out_allocated);
	if (wksp != NULL)
		large_free(wksp);
	return err;
}
0329
0330 #ifndef UNZSTD_PREBOOT
0331 STATIC int INIT unzstd(unsigned char *buf, long len,
0332 long (*fill)(void*, unsigned long),
0333 long (*flush)(void*, unsigned long),
0334 unsigned char *out_buf,
0335 long *pos,
0336 void (*error)(char *x))
0337 {
0338 return __unzstd(buf, len, fill, flush, out_buf, 0, pos, error);
0339 }
0340 #else
0341 STATIC int INIT __decompress(unsigned char *buf, long len,
0342 long (*fill)(void*, unsigned long),
0343 long (*flush)(void*, unsigned long),
0344 unsigned char *out_buf, long out_len,
0345 long *pos,
0346 void (*error)(char *x))
0347 {
0348 return __unzstd(buf, len, fill, flush, out_buf, out_len, pos, error);
0349 }
0350 #endif