0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 #include "mem.h"
0019 #include "error_private.h" /* ERR_*, ERROR */
0020 #define FSE_STATIC_LINKING_ONLY
0021 #include "fse.h"
0022 #define HUF_STATIC_LINKING_ONLY
0023 #include "huf.h"
0024
0025
0026
0027 unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
0028
0029
0030
/* FSE error-code helpers.
 * Both forward to the shared ERR_* routines, so FSE uses the same
 * error-code encoding as the rest of the library. */
unsigned FSE_isError(size_t code)
{
    unsigned const failed = ERR_isError(code);
    return failed;
}

const char* FSE_getErrorName(size_t code)
{
    const char* const name = ERR_getErrorName(code);
    return name;
}
0033
/* HUF error-code helpers.
 * Both forward to the shared ERR_* routines, so Huffman functions share the
 * same error-code encoding as FSE. */
unsigned HUF_isError(size_t code)
{
    unsigned const failed = ERR_isError(code);
    return failed;
}

const char* HUF_getErrorName(size_t code)
{
    const char* const name = ERR_getErrorName(code);
    return name;
}
0036
0037
0038
0039
0040
0041 static U32 FSE_ctz(U32 val)
0042 {
0043 assert(val != 0);
0044 {
0045 # if (__GNUC__ >= 3)
0046 return __builtin_ctz(val);
0047 # else
0048 U32 count = 0;
0049 while ((val & 1) == 0) {
0050 val >>= 1;
0051 ++count;
0052 }
0053 return count;
0054 # endif
0055 }
0056 }
0057
/* FSE_readNCount_body() :
 * Decodes an FSE normalized-count header stored in headerBuffer[0..hbSize).
 * @param normalizedCounter  output array; entries 0..*maxSVPtr are zeroed,
 *                           then filled with decoded counts (a count may be -1).
 * @param maxSVPtr           in : largest symbol value the caller can accept;
 *                           out: index of the last symbol present in the header.
 * @param tableLogPtr        out: table log (low 4 header bits + FSE_MIN_TABLELOG).
 * @return number of header bytes consumed, or an error code
 *         (test with FSE_isError()).
 * The decoding loop assumes hbSize >= 8, so shorter inputs are re-decoded
 * from a zero-padded 8-byte copy. */
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;              /* field width (in bits) of the next count */
    int remaining;           /* undistributed probability units, plus 1 */
    int threshold;
    U32 bitStream;           /* 32-bit window read at ip, already shifted by bitCount */
    int bitCount;            /* bits of the window at ip already consumed */
    unsigned charnum = 0;    /* index of the next symbol to decode */
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;       /* non-zero when the last decoded count was 0 */

    if (hbSize < 8) {
        /* Decode from a zero-padded copy; if decoding consumed any padding
         * byte, the real header was truncated. */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init : symbols absent from the header keep a count of 0 */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* low 4 bits: tableLog - FSE_MIN_TABLELOG */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* A zero count is followed by 2-bit repeat codes giving the number
             * of additional zero-count symbols. Each 0b11 code adds 3 and
             * continues the run; the first code that is not 0b11 adds 0..2
             * and ends it.
             * FSE_ctz(~bitStream) >> 1 counts consecutive 0b11 codes at the
             * bottom of bitStream. OR-ing in the top bit keeps FSE_ctz's
             * argument non-zero. */
            int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            while (repeats >= 12) {
                /* 12 codes = 24 bits = 3 bytes: skip them and reload. */
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    /* ip may not move past iend-4: clamp it and carry the
                     * rest of the advance in bitCount instead. */
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Final repeat code of the run (never 0b11). */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* Too many symbols: leave the loop; the checks after the loop
             * turn this into an error return. */
            if (charnum >= maxSV1) break;

            /* The skipped symbols already hold 0 from the memset above. */

            /* Refill: advance ip by the whole bytes consumed, but never past
             * iend-4, the last position where a 4-byte read stays in bounds. */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3);
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            /* Variable-width count: values below `max` are read with
             * nbBits-1 bits, the others with nbBits bits. */
            int const max = (2*threshold-1) - remaining;
            int count;

            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* stored value is count+1, so -1 is representable */

            /* A count of -1 still consumes one unit of `remaining`. */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;   /* a zero count is followed by repeat codes */

            assert(threshold > 1);
            if (remaining < threshold) {
                /* Shrink the field width to what `remaining` needs. Once
                 * remaining <= 1, every probability unit has been assigned. */
                if (remaining <= 1) break;
                nbBits = BIT_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            /* Refill (same clamping scheme as in the zero-run branch). */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    /* All (1 << tableLog) units must have been distributed exactly. */
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only a zero run can push charnum past maxSV1. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    /* Consuming more than the 32-bit window at ip means reading past the input. */
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    /* Round consumed bits up to whole bytes. */
    ip += (bitCount+7)>>3;
    return ip-istart;
}
0205
0206
/* Generic instantiation of FSE_readNCount_body(), compiled without any
 * target attribute. It is always built and serves as the fallback path. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const headerSize =
        FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                            headerBuffer, hbSize);
    return headerSize;
}
0213
#if DYNAMIC_BMI2
/* Instantiation of FSE_readNCount_body() compiled with TARGET_ATTRIBUTE("bmi2").
 * Only built when DYNAMIC_BMI2 is enabled; selected at run time by
 * FSE_readNCount_bmi2(). */
TARGET_ATTRIBUTE("bmi2") static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const headerSize =
        FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                            headerBuffer, hbSize);
    return headerSize;
}
#endif
0222
/* FSE_readNCount_bmi2() :
 * Same contract as FSE_readNCount(). When DYNAMIC_BMI2 is compiled in and
 * bmi2 is non-zero, the BMI2-targeted decoder is used. Otherwise the flag is
 * ignored and the generic decoder runs. */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    return bmi2
        ? FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize)
        : FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
#else
    (void)bmi2;   /* no BMI2 variant in this build */
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
#endif
}
0235
/* FSE_readNCount() :
 * Public entry point for decoding an FSE normalized-count header.
 * Always takes the generic (non-BMI2) path. */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    int const useBmi2 = 0;
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                               headerBuffer, hbSize, useBmi2);
}
0242
0243
0244
0245
0246
0247
0248
0249
0250
0251 size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
0252 U32* nbSymbolsPtr, U32* tableLogPtr,
0253 const void* src, size_t srcSize)
0254 {
0255 U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
0256 return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), 0);
0257 }
0258
0259 FORCE_INLINE_TEMPLATE size_t
0260 HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
0261 U32* nbSymbolsPtr, U32* tableLogPtr,
0262 const void* src, size_t srcSize,
0263 void* workSpace, size_t wkspSize,
0264 int bmi2)
0265 {
0266 U32 weightTotal;
0267 const BYTE* ip = (const BYTE*) src;
0268 size_t iSize;
0269 size_t oSize;
0270
0271 if (!srcSize) return ERROR(srcSize_wrong);
0272 iSize = ip[0];
0273
0274
0275 if (iSize >= 128) {
0276 oSize = iSize - 127;
0277 iSize = ((oSize+1)/2);
0278 if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
0279 if (oSize >= hwSize) return ERROR(corruption_detected);
0280 ip += 1;
0281 { U32 n;
0282 for (n=0; n<oSize; n+=2) {
0283 huffWeight[n] = ip[n/2] >> 4;
0284 huffWeight[n+1] = ip[n/2] & 15;
0285 } } }
0286 else {
0287 if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
0288
0289 oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
0290 if (FSE_isError(oSize)) return oSize;
0291 }
0292
0293
0294 ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
0295 weightTotal = 0;
0296 { U32 n; for (n=0; n<oSize; n++) {
0297 if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
0298 rankStats[huffWeight[n]]++;
0299 weightTotal += (1 << huffWeight[n]) >> 1;
0300 } }
0301 if (weightTotal == 0) return ERROR(corruption_detected);
0302
0303
0304 { U32 const tableLog = BIT_highbit32(weightTotal) + 1;
0305 if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
0306 *tableLogPtr = tableLog;
0307
0308 { U32 const total = 1 << tableLog;
0309 U32 const rest = total - weightTotal;
0310 U32 const verif = 1 << BIT_highbit32(rest);
0311 U32 const lastWeight = BIT_highbit32(rest) + 1;
0312 if (verif != rest) return ERROR(corruption_detected);
0313 huffWeight[oSize] = (BYTE)lastWeight;
0314 rankStats[lastWeight]++;
0315 } }
0316
0317
0318 if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);
0319
0320
0321 *nbSymbolsPtr = (U32)(oSize+1);
0322 return iSize+1;
0323 }
0324
0325
0326 static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
0327 U32* nbSymbolsPtr, U32* tableLogPtr,
0328 const void* src, size_t srcSize,
0329 void* workSpace, size_t wkspSize)
0330 {
0331 return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
0332 }
0333
#if DYNAMIC_BMI2
/* Instantiation of HUF_readStats_body() compiled with TARGET_ATTRIBUTE("bmi2").
 * It passes bmi2 = 1 down to the FSE weight decoder. Only built when
 * DYNAMIC_BMI2 is enabled. */
static TARGET_ATTRIBUTE("bmi2") size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                                                               U32* nbSymbolsPtr, U32* tableLogPtr,
                                                               const void* src, size_t srcSize,
                                                               void* workSpace, size_t wkspSize)
{
    int const useBmi2 = 1;
    size_t const consumed =
        HUF_readStats_body(huffWeight, hwSize, rankStats,
                           nbSymbolsPtr, tableLogPtr,
                           src, srcSize,
                           workSpace, wkspSize, useBmi2);
    return consumed;
}
#endif
0343
0344 size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
0345 U32* nbSymbolsPtr, U32* tableLogPtr,
0346 const void* src, size_t srcSize,
0347 void* workSpace, size_t wkspSize,
0348 int bmi2)
0349 {
0350 #if DYNAMIC_BMI2
0351 if (bmi2) {
0352 return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
0353 }
0354 #endif
0355 (void)bmi2;
0356 return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
0357 }