Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 
0003 /*
0004  * Important notes about in-place decompression
0005  *
 * At least on x86, the kernel is decompressed in place: the compressed data
 * is placed at the end of the output buffer, and the decompressor overwrites
 * most of the compressed data. There must be enough safety margin to
 * guarantee that the write position always stays behind the read position.
0010  *
0011  * The safety margin for ZSTD with a 128 KB block size is calculated below.
0012  * Note that the margin with ZSTD is bigger than with GZIP or XZ!
0013  *
0014  * The worst case for in-place decompression is that the beginning of
0015  * the file is compressed extremely well, and the rest of the file is
0016  * uncompressible. Thus, we must look for worst-case expansion when the
0017  * compressor is encoding uncompressible data.
0018  *
 * The structure of the .zst file in case of a compressed kernel is as follows.
 * Maximum sizes (in bytes) of the fields are in parentheses.
0021  *
0022  *    Frame Header: (18)
0023  *    Blocks: (N)
0024  *    Checksum: (4)
0025  *
0026  * The frame header and checksum overhead is at most 22 bytes.
0027  *
 * ZSTD stores the data in blocks. Each block has a header whose size is
 * 3 bytes. After the block header, there is up to 128 KB of payload.
 * The maximum uncompressed size of the payload is 128 KB. The uncompressed
 * size of the payload is never less than the size of the payload itself
 * (excluding the block header).
0033  *
0034  * The assumption, that the uncompressed size of the payload is never
0035  * smaller than the payload itself, is valid only when talking about
0036  * the payload as a whole. It is possible that the payload has parts where
0037  * the decompressor consumes more input than it produces output. Calculating
0038  * the worst case for this would be tricky. Instead of trying to do that,
0039  * let's simply make sure that the decompressor never overwrites any bytes
0040  * of the payload which it is currently reading.
0041  *
 * Now we have enough information to calculate the safety margin. We need
 *   - 22 bytes for the .zst file format headers;
 *   - 3 bytes for every 128 KiB of uncompressed size (one block header per
 *     block); and
 *   - 128 KiB (biggest possible zstd block size) to make sure that the
 *     decompressor never overwrites anything from the block it is currently
 *     reading.
 *
 * We get the following formula:
 *
 *    safety_margin = 22 + uncompressed_size * 3 / 131072 + 131072
 *                 <= 22 + (uncompressed_size >> 15) + 131072
 *
 * The second line is a simpler upper bound: (uncompressed_size >> 15) is
 * uncompressed_size * 4 / 131072, which is never less than the 3 / 131072
 * term above.
0054  */
0055 
0056 /*
0057  * Preboot environments #include "path/to/decompress_unzstd.c".
0058  * All of the source files we depend on must be #included.
0059  * zstd's only source dependency is xxhash, which has no source
0060  * dependencies.
0061  *
0062  * When UNZSTD_PREBOOT is defined we declare __decompress(), which is
0063  * used for kernel decompression, instead of unzstd().
0064  *
0065  * Define __DISABLE_EXPORTS in preboot environments to prevent symbols
0066  * from xxhash and zstd from being exported by the EXPORT_SYMBOL macro.
0067  */
0068 #ifdef STATIC
0069 # define UNZSTD_PREBOOT
0070 # include "xxhash.c"
0071 # include "zstd/decompress_sources.h"
0072 #endif
0073 
0074 #include <linux/decompress/mm.h>
0075 #include <linux/kernel.h>
0076 #include <linux/zstd.h>
0077 
/*
 * Largest window size accepted from a frame header in streaming mode.
 *
 * ZSTD_WINDOWLOG_MAX comes from the zstd library headers. If it is 31 or
 * more, a plain "1 << ZSTD_WINDOWLOG_MAX" overflows a signed int (undefined
 * behaviour; INT_MIN with GCC). That negative value would then be converted
 * to a huge unsigned number when compared against the unsigned 64-bit
 * zstd_frame_header.windowSize, silently disabling the size check.
 * Do the shift in unsigned long long instead.
 */
#define ZSTD_WINDOWSIZE_MAX (1ULL << ZSTD_WINDOWLOG_MAX)
/*
 * Size of the input and output buffers in multi-call mode.
 * Pick a larger size because it isn't used during kernel decompression,
 * since that is single pass, and we have to allocate a large buffer for
 * zstd's window anyway. The larger size speeds up initramfs decompression.
 */
#define ZSTD_IOBUF_SIZE     (1 << 17)
0087 
0088 static int INIT handle_zstd_error(size_t ret, void (*error)(char *x))
0089 {
0090     const zstd_error_code err = zstd_get_error_code(ret);
0091 
0092     if (!zstd_is_error(ret))
0093         return 0;
0094 
0095     /*
0096      * zstd_get_error_name() cannot be used because error takes a char *
0097      * not a const char *
0098      */
0099     switch (err) {
0100     case ZSTD_error_memory_allocation:
0101         error("ZSTD decompressor ran out of memory");
0102         break;
0103     case ZSTD_error_prefix_unknown:
0104         error("Input is not in the ZSTD format (wrong magic bytes)");
0105         break;
0106     case ZSTD_error_dstSize_tooSmall:
0107     case ZSTD_error_corruption_detected:
0108     case ZSTD_error_checksum_wrong:
0109         error("ZSTD-compressed data is corrupt");
0110         break;
0111     default:
0112         error("ZSTD-compressed data is probably corrupt");
0113         break;
0114     }
0115     return -1;
0116 }
0117 
0118 /*
0119  * Handle the case where we have the entire input and output in one segment.
0120  * We can allocate less memory (no circular buffer for the sliding window),
0121  * and avoid some memcpy() calls.
0122  */
0123 static int INIT decompress_single(const u8 *in_buf, long in_len, u8 *out_buf,
0124                   long out_len, long *in_pos,
0125                   void (*error)(char *x))
0126 {
0127     const size_t wksp_size = zstd_dctx_workspace_bound();
0128     void *wksp = large_malloc(wksp_size);
0129     zstd_dctx *dctx = zstd_init_dctx(wksp, wksp_size);
0130     int err;
0131     size_t ret;
0132 
0133     if (dctx == NULL) {
0134         error("Out of memory while allocating zstd_dctx");
0135         err = -1;
0136         goto out;
0137     }
0138     /*
0139      * Find out how large the frame actually is, there may be junk at
0140      * the end of the frame that zstd_decompress_dctx() can't handle.
0141      */
0142     ret = zstd_find_frame_compressed_size(in_buf, in_len);
0143     err = handle_zstd_error(ret, error);
0144     if (err)
0145         goto out;
0146     in_len = (long)ret;
0147 
0148     ret = zstd_decompress_dctx(dctx, out_buf, out_len, in_buf, in_len);
0149     err = handle_zstd_error(ret, error);
0150     if (err)
0151         goto out;
0152 
0153     if (in_pos != NULL)
0154         *in_pos = in_len;
0155 
0156     err = 0;
0157 out:
0158     if (wksp != NULL)
0159         large_free(wksp);
0160     return err;
0161 }
0162 
/*
 * Decompress one zstd frame, either between caller-supplied buffers or
 * incrementally through the fill()/flush() callbacks.
 *
 * @in_buf:  compressed input. If NULL, an internal buffer of
 *           ZSTD_IOBUF_SIZE bytes is allocated and filled with fill().
 * @in_len:  number of valid bytes in @in_buf; overwritten by the return
 *           value of fill() when fill() is used.
 * @fill:    optional callback that refills the input buffer; a negative
 *           return is reported as truncated input.
 * @flush:   optional callback that consumes produced output; anything other
 *           than accepting every byte it is given is treated as failure.
 * @out_buf: output buffer; replaced by an internal ZSTD_IOBUF_SIZE buffer
 *           when flush() is used.
 * @out_len: size of @out_buf; 0 means "as large as possible without the
 *           end address overflowing".
 * @in_pos:  if non-NULL, receives the number of input bytes consumed.
 * @error:   callback used to report a human-readable error message.
 *
 * When neither fill() nor flush() is given, the work is handed to
 * decompress_single().
 *
 * Return: 0 on success, -1 on failure (after calling @error).
 */
static int INIT __unzstd(unsigned char *in_buf, long in_len,
             long (*fill)(void*, unsigned long),
             long (*flush)(void*, unsigned long),
             unsigned char *out_buf, long out_len,
             long *in_pos,
             void (*error)(char *x))
{
    zstd_in_buffer in;
    zstd_out_buffer out;
    zstd_frame_header header;
    /* Buffers owned by this function; released at the "out" label. */
    void *in_allocated = NULL;
    void *out_allocated = NULL;
    void *wksp = NULL;
    size_t wksp_size;
    zstd_dstream *dstream;
    int err;
    size_t ret;

    /*
     * ZSTD decompression code won't be happy if the buffer size is so big
     * that its end address overflows. When the size is not provided, make
     * it as big as possible without having the end address overflow.
     */
    if (out_len == 0)
        out_len = UINTPTR_MAX - (uintptr_t)out_buf;

    if (fill == NULL && flush == NULL)
        /*
         * We can decompress faster and with less memory when we have a
         * single chunk.
         */
        return decompress_single(in_buf, in_len, out_buf, out_len,
                     in_pos, error);

    /*
     * If in_buf is not provided, we must be using fill(), so allocate
     * a large enough buffer. If it is provided and fill() is used, it must
     * be at least ZSTD_IOBUF_SIZE large, because fill() is asked for up to
     * that many bytes at a time.
     */
    if (in_buf == NULL) {
        in_allocated = large_malloc(ZSTD_IOBUF_SIZE);
        if (in_allocated == NULL) {
            error("Out of memory while allocating input buffer");
            err = -1;
            goto out;
        }
        in_buf = in_allocated;
        in_len = 0;
    }
    /* Read the first chunk, since we need to decode the frame header. */
    if (fill != NULL)
        in_len = fill(in_buf, ZSTD_IOBUF_SIZE);
    if (in_len < 0) {
        error("ZSTD-compressed data is truncated");
        err = -1;
        goto out;
    }
    /* Point the zstd input buffer at the first chunk (it may be empty). */
    in.src = in_buf;
    in.pos = 0;
    in.size = in_len;
    /* Allocate the output buffer if we are using flush(). */
    if (flush != NULL) {
        out_allocated = large_malloc(ZSTD_IOBUF_SIZE);
        if (out_allocated == NULL) {
            error("Out of memory while allocating output buffer");
            err = -1;
            goto out;
        }
        out_buf = out_allocated;
        out_len = ZSTD_IOBUF_SIZE;
    }
    /* Set the output buffer. */
    out.dst = out_buf;
    out.pos = 0;
    out.size = out_len;

    /*
     * We need to know the window size to allocate the zstd_dstream.
     * Since we are streaming, we need to allocate a buffer for the sliding
     * window. The window size can be as small as 1 KB or as large as
     * ZSTD_WINDOWSIZE_MAX, so it is important to use the actual value so
     * as not to waste memory when it is smaller.
     */
    ret = zstd_get_frame_header(&header, in.src, in.size);
    err = handle_zstd_error(ret, error);
    if (err)
        goto out;
    /*
     * A non-zero, non-error return means the first chunk was too short to
     * hold the complete frame header. No further input is read to finish
     * it; this is reported as an error instead.
     */
    if (ret != 0) {
        error("ZSTD-compressed data has an incomplete frame header");
        err = -1;
        goto out;
    }
    if (header.windowSize > ZSTD_WINDOWSIZE_MAX) {
        error("ZSTD-compressed data has too large a window size");
        err = -1;
        goto out;
    }

    /*
     * Allocate the zstd_dstream now that we know how much memory is
     * required. A failed large_malloc() is caught through the NULL check
     * on the dstream below.
     */
    wksp_size = zstd_dstream_workspace_bound(header.windowSize);
    wksp = large_malloc(wksp_size);
    dstream = zstd_init_dstream(header.windowSize, wksp, wksp_size);
    if (dstream == NULL) {
        error("Out of memory while allocating ZSTD_DStream");
        err = -1;
        goto out;
    }

    /*
     * Decompression loop:
     * Read more data if necessary (error if no more data can be read).
     * Call the decompression function, which returns 0 when finished.
     * Flush any data produced if using flush().
     *
     * *in_pos is accumulated chunk by chunk: each fully consumed chunk is
     * added before the refill, and the partially consumed last chunk is
     * added after the loop.
     */
    if (in_pos != NULL)
        *in_pos = 0;
    do {
        /*
         * If we need to reload data, either we have fill() and can
         * try to get more data, or we don't and the input is truncated.
         */
        if (in.pos == in.size) {
            if (in_pos != NULL)
                *in_pos += in.pos;
            in_len = fill ? fill(in_buf, ZSTD_IOBUF_SIZE) : -1;
            if (in_len < 0) {
                error("ZSTD-compressed data is truncated");
                err = -1;
                goto out;
            }
            in.pos = 0;
            in.size = in_len;
        }
        /* Returns zero when the frame is complete. */
        ret = zstd_decompress_stream(dstream, &out, &in);
        err = handle_zstd_error(ret, error);
        if (err)
            goto out;
        /*
         * Flush all of the data produced if using flush(). A short or
         * negative return from flush() is treated as failure.
         */
        if (flush != NULL && out.pos > 0) {
            if (out.pos != flush(out.dst, out.pos)) {
                error("Failed to flush()");
                err = -1;
                goto out;
            }
            out.pos = 0;
        }
    } while (ret != 0);

    if (in_pos != NULL)
        *in_pos += in.pos;

    err = 0;
out:
    if (in_allocated != NULL)
        large_free(in_allocated);
    if (out_allocated != NULL)
        large_free(out_allocated);
    if (wksp != NULL)
        large_free(wksp);
    return err;
}
0329 
0330 #ifndef UNZSTD_PREBOOT
0331 STATIC int INIT unzstd(unsigned char *buf, long len,
0332                long (*fill)(void*, unsigned long),
0333                long (*flush)(void*, unsigned long),
0334                unsigned char *out_buf,
0335                long *pos,
0336                void (*error)(char *x))
0337 {
0338     return __unzstd(buf, len, fill, flush, out_buf, 0, pos, error);
0339 }
0340 #else
0341 STATIC int INIT __decompress(unsigned char *buf, long len,
0342                  long (*fill)(void*, unsigned long),
0343                  long (*flush)(void*, unsigned long),
0344                  unsigned char *out_buf, long out_len,
0345                  long *pos,
0346                  void (*error)(char *x))
0347 {
0348     return __unzstd(buf, len, fill, flush, out_buf, out_len, pos, error);
0349 }
0350 #endif