diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index a2c046ec7..8a86cd5c6 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -1251,7 +1251,7 @@ static void ZSTD_assertEqualCParams(ZSTD_compressionParameters cParams1, assert(cParams1.strategy == cParams2.strategy); } -static void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs) +void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs) { int i; for (i = 0; i < ZSTD_REP_NUM; ++i) @@ -2772,37 +2772,13 @@ static size_t ZSTD_checkDictNCount(short* normalizedCounter, unsigned dictMaxSym return 0; } - -/* Dictionary format : - * See : - * https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format - */ -/*! ZSTD_loadZstdDictionary() : - * @return : dictID, or an error code - * assumptions : magic number supposed already checked - * dictSize supposed >= 8 - */ -static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, - ZSTD_matchState_t* ms, - ZSTD_cwksp* ws, - ZSTD_CCtx_params const* params, - const void* dict, size_t dictSize, - ZSTD_dictTableLoadMethod_e dtlm, - void* workspace) +size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, + short* offcodeNCount, unsigned* offcodeMaxValue, + const void* const dict, size_t dictSize) { - const BYTE* dictPtr = (const BYTE*)dict; + const BYTE* dictPtr = (const BYTE*)dict; /* skip magic num and dict ID */ const BYTE* const dictEnd = dictPtr + dictSize; - short offcodeNCount[MaxOff+1]; - unsigned offcodeMaxValue = MaxOff; - size_t dictID; - - ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog))); - assert(dictSize >= 8); - assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY); - - dictPtr += 4; /* skip magic number */ - dictID = params->fParams.noDictIDFlag ? 
0 : MEM_readLE32(dictPtr); - dictPtr += 4; + dictPtr += 8; { unsigned maxSymbolValue = 255; size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr); @@ -2812,7 +2788,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, } { unsigned offcodeLog; - size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, dictEnd-dictPtr); + size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, offcodeMaxValue, &offcodeLog, dictPtr, dictEnd-dictPtr); RETURN_ERROR_IF(FSE_isError(offcodeHeaderSize), dictionary_corrupted); RETURN_ERROR_IF(offcodeLog > OffFSELog, dictionary_corrupted); /* Defer checking offcodeMaxValue because we need to know the size of the dictionary content */ @@ -2861,6 +2837,42 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, bs->rep[2] = MEM_readLE32(dictPtr+8); dictPtr += 12; + return dictPtr - (const BYTE*)dict; +} + +/* Dictionary format : + * See : + * https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format + */ +/*! ZSTD_loadZstdDictionary() : + * @return : dictID, or an error code + * assumptions : magic number supposed already checked + * dictSize supposed >= 8 + */ +static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, + ZSTD_matchState_t* ms, + ZSTD_cwksp* ws, + ZSTD_CCtx_params const* params, + const void* dict, size_t dictSize, + ZSTD_dictTableLoadMethod_e dtlm, + void* workspace) +{ + const BYTE* dictPtr = (const BYTE*)dict; + const BYTE* const dictEnd = dictPtr + dictSize; + short offcodeNCount[MaxOff+1]; + unsigned offcodeMaxValue = MaxOff; + size_t dictID; + size_t eSize; + + ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog))); + assert(dictSize >= 8); + assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY); + + dictID = params->fParams.noDictIDFlag ? 
0 : MEM_readLE32(dictPtr + 4 /* skip magic number */ ); + eSize = ZSTD_loadCEntropy(bs, workspace, offcodeNCount, &offcodeMaxValue, dict, dictSize); + FORWARD_IF_ERROR(eSize); + dictPtr += eSize; + { size_t const dictContentSize = (size_t)(dictEnd - dictPtr); U32 offcodeMax = MaxOff; if (dictContentSize <= ((U32)-1) - 128 KB) { diff --git a/lib/compress/zstd_compress_internal.h b/lib/compress/zstd_compress_internal.h index 4ed09890e..b0a309884 100644 --- a/lib/compress/zstd_compress_internal.h +++ b/lib/compress/zstd_compress_internal.h @@ -931,6 +931,21 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max) } #endif +/* =============================================================== + * Shared internal declarations + * These prototypes may be called from sources not in lib/compress + * =============================================================== */ + +/* ZSTD_loadCEntropy() : + * dict : must point at beginning of a valid zstd dictionary. + * return : size of dictionary header (size of magic number + dict ID + entropy tables) + * assumptions : magic number supposed already checked + * and dictSize >= 8 */ +size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, + short* offcodeNCount, unsigned* offcodeMaxValue, + const void* const dict, size_t dictSize); + +void ZSTD_reset_compressedBlockState(ZSTD_compressedBlockState_t* bs); /* ============================================================== * Private declarations diff --git a/lib/decompress/zstd_decompress_internal.h b/lib/decompress/zstd_decompress_internal.h index ccbdfa090..99eab854c 100644 --- a/lib/decompress/zstd_decompress_internal.h +++ b/lib/decompress/zstd_decompress_internal.h @@ -160,7 +160,7 @@ struct ZSTD_DCtx_s /*! ZSTD_loadDEntropy() : * dict : must point at beginning of a valid zstd dictionary. 
- * @return : size of entropy tables read */ + * @return : size of dictionary header (size of magic number + dict ID + entropy tables) */ size_t ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, const void* const dict, size_t const dictSize); diff --git a/lib/dictBuilder/zdict.c b/lib/dictBuilder/zdict.c index 1e7f83432..344ab446b 100644 --- a/lib/dictBuilder/zdict.c +++ b/lib/dictBuilder/zdict.c @@ -48,6 +48,7 @@ # define ZDICT_STATIC_LINKING_ONLY #endif #include "zdict.h" +#include "compress/zstd_compress_internal.h" /* ZSTD_loadCEntropy() */ /*-************************************* @@ -99,6 +100,29 @@ unsigned ZDICT_getDictID(const void* dictBuffer, size_t dictSize) return MEM_readLE32((const char*)dictBuffer + 4); } +size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize) +{ + size_t headerSize; + if (dictSize <= 8 || MEM_readLE32(dictBuffer) != ZSTD_MAGIC_DICTIONARY) return ERROR(dictionary_corrupted); + + { unsigned offcodeMaxValue = MaxOff; + ZSTD_compressedBlockState_t* bs = (ZSTD_compressedBlockState_t*)malloc(sizeof(ZSTD_compressedBlockState_t)); + U32* wksp = (U32*)malloc(HUF_WORKSPACE_SIZE); + short* offcodeNCount = (short*)malloc((MaxOff+1)*sizeof(short)); + if (!bs || !wksp || !offcodeNCount) { + headerSize = ERROR(memory_allocation); + } else { + ZSTD_reset_compressedBlockState(bs); + headerSize = ZSTD_loadCEntropy(bs, wksp, offcodeNCount, &offcodeMaxValue, dictBuffer, dictSize); + } + + free(bs); + free(wksp); + free(offcodeNCount); + } + + return headerSize; +} /*-******************************************************** * Dictionary training functions diff --git a/lib/dictBuilder/zdict.h b/lib/dictBuilder/zdict.h index 37978ecdf..1313bd214 100644 --- a/lib/dictBuilder/zdict.h +++ b/lib/dictBuilder/zdict.h @@ -64,6 +64,7 @@ ZDICTLIB_API size_t ZDICT_trainFromBuffer(void* dictBuffer, size_t dictBufferCap /*====== Helper functions ======*/ ZDICTLIB_API unsigned ZDICT_getDictID(const void* dictBuffer, size_t dictSize); /**< extracts 
dictID; @return zero if error (not a valid dictionary) */ +ZDICTLIB_API size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize); /* returns dict header size; returns a ZSTD error code on failure */ ZDICTLIB_API unsigned ZDICT_isError(size_t errorCode); ZDICTLIB_API const char* ZDICT_getErrorName(size_t errorCode); diff --git a/tests/fuzzer.c b/tests/fuzzer.c index 65dc6f50a..a61667ed3 100644 --- a/tests/fuzzer.c +++ b/tests/fuzzer.c @@ -1159,6 +1159,7 @@ static int basicUnitTests(U32 const seed, double compressibility) size_t* const samplesSizes = (size_t*) malloc(nbSamples * sizeof(size_t)); size_t dictSize; U32 dictID; + size_t dictHeaderSize; if (dictBuffer==NULL || samplesSizes==NULL) { free(dictBuffer); @@ -1248,6 +1249,29 @@ static int basicUnitTests(U32 const seed, double compressibility) if (dictID==0) goto _output_error; DISPLAYLEVEL(3, "OK : %u \n", (unsigned)dictID); + DISPLAYLEVEL(3, "test%3i : check dict header size no error : ", testNb++); + dictHeaderSize = ZDICT_getDictHeaderSize(dictBuffer, dictSize); + if (dictHeaderSize==0) goto _output_error; + DISPLAYLEVEL(3, "OK : %u \n", (unsigned)dictHeaderSize); + + DISPLAYLEVEL(3, "test%3i : check dict header size correctness : ", testNb++); + { unsigned char const dictBufferFixed[144] = { 0x37, 0xa4, 0x30, 0xec, 0x63, 0x00, 0x00, 0x00, 0x08, 0x10, 0x00, 0x1f, + 0x0f, 0x00, 0x28, 0xe5, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x80, 0x0f, 0x9e, 0x0f, 0x00, 0x00, 0x24, 0x40, 0x80, 0x00, 0x01, + 0x02, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0xde, 0x08, + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, + 0x08, 0x08, 0x08, 0x08, 0xbc, 0xe1, 0x4b, 0x92, 0x0e, 0xb4, 0x7b, 0x18, + 0x86, 0x61, 0x18, 0xc6, 0x18, 0x63, 0x8c, 0x31, 0xc6, 0x18, 0x63, 0x8c, + 0x31, 0x66, 0x66, 0x66, 0x66, 0xb6, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x20, 
0x73, 0x6f, 0x64, 0x61, + 0x6c, 0x65, 0x73, 0x20, 0x74, 0x6f, 0x72, 0x74, 0x6f, 0x72, 0x20, 0x65, + 0x6c, 0x65, 0x69, 0x66, 0x65, 0x6e, 0x64, 0x2e, 0x20, 0x41, 0x6c, 0x69 }; + dictHeaderSize = ZDICT_getDictHeaderSize(dictBufferFixed, 144); + if (dictHeaderSize != 115) goto _output_error; + } + DISPLAYLEVEL(3, "OK : %u \n", (unsigned)dictHeaderSize); + DISPLAYLEVEL(3, "test%3i : compress with dictionary : ", testNb++); cSize = ZSTD_compress_usingDict(cctx, compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize,