diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index f26e5ff61..c8cfa716e 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -115,6 +115,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) dctx->bmi2 = ZSTD_cpuid_bmi2(ZSTD_cpuid()); dctx->outBufferMode = ZSTD_obm_buffered; dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum; + dctx->validateChecksum = 1; #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION dctx->dictContentEndForFuzzing = NULL; #endif @@ -447,7 +448,8 @@ static size_t ZSTD_decodeFrameHeader(ZSTD_DCtx* dctx, const void* src, size_t he RETURN_ERROR_IF(dctx->fParams.dictID && (dctx->dictID != dctx->fParams.dictID), dictionary_wrong, ""); #endif - if (dctx->fParams.checksumFlag && !dctx->forceIgnoreChecksum) XXH64_reset(&dctx->xxhState, 0); + dctx->validateChecksum = (dctx->fParams.checksumFlag && !dctx->forceIgnoreChecksum) ? 1 : 0; + if (dctx->validateChecksum) XXH64_reset(&dctx->xxhState, 0); return 0; } @@ -662,7 +664,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, } if (ZSTD_isError(decodedSize)) return decodedSize; - if (dctx->fParams.checksumFlag && !dctx->forceIgnoreChecksum) + if (dctx->validateChecksum) XXH64_update(&dctx->xxhState, op, decodedSize); if (decodedSize != 0) op += decodedSize; @@ -980,7 +982,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c RETURN_ERROR_IF(rSize > dctx->fParams.blockSizeMax, corruption_detected, "Decompressed Block Size Exceeds Maximum"); DEBUGLOG(5, "ZSTD_decompressContinue: decoded size from block : %u", (unsigned)rSize); dctx->decodedSize += rSize; - if (dctx->fParams.checksumFlag && !dctx->forceIgnoreChecksum) XXH64_update(&dctx->xxhState, dst, rSize); + if (dctx->validateChecksum) XXH64_update(&dctx->xxhState, dst, rSize); dctx->previousDstEnd = (char*)dst + rSize; /* Stay on the same stage until we are finished streaming the block. */ @@ -1011,7 +1013,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c case ZSTDds_checkChecksum: assert(srcSize == 4); /* guaranteed by dctx->expected */ { - if (!dctx->forceIgnoreChecksum) { + if (dctx->validateChecksum) { U32 const h32 = (U32)XXH64_digest(&dctx->xxhState); U32 const check32 = MEM_readLE32(src); DEBUGLOG(4, "ZSTD_decompressContinue: checksum : calculated %08X :: %08X read", (unsigned)h32, (unsigned)check32); diff --git a/lib/decompress/zstd_decompress_internal.h b/lib/decompress/zstd_decompress_internal.h index 178023806..891d1bedb 100644 --- a/lib/decompress/zstd_decompress_internal.h +++ b/lib/decompress/zstd_decompress_internal.h @@ -122,7 +122,8 @@ struct ZSTD_DCtx_s XXH64_state_t xxhState; size_t headerSize; ZSTD_format_e format; - ZSTD_forceIgnoreChecksum_e forceIgnoreChecksum; /* if == 1, will ignore checksums in compressed frame */ + ZSTD_forceIgnoreChecksum_e forceIgnoreChecksum; /* User specified: if == 1, will ignore checksums in compressed frame. Default == 0 */ + U32 validateChecksum; /* if == 1, will validate checksum. Is == 1 if (fParams.checksumFlag == 1) and (forceIgnoreChecksum == 0). */ const BYTE* litPtr; ZSTD_customMem customMem; size_t litSize; diff --git a/programs/fileio.c b/programs/fileio.c index f2b8447af..1970f6cb2 100644 --- a/programs/fileio.c +++ b/programs/fileio.c @@ -1754,11 +1754,7 @@ static dRess_t FIO_createDResources(FIO_prefs_t* const prefs, const char* dictFi if (ress.dctx==NULL) EXM_THROW(60, "Error: %s : can't create ZSTD_DStream", strerror(errno)); CHECK( ZSTD_DCtx_setMaxWindowSize(ress.dctx, prefs->memLimit) ); - if (!prefs->checksumFlag) { - CHECK( ZSTD_DCtx_setParameter(ress.dctx, ZSTD_d_forceIgnoreChecksum, ZSTD_d_ignoreChecksum)); - } else { - CHECK( ZSTD_DCtx_setParameter(ress.dctx, ZSTD_d_forceIgnoreChecksum, ZSTD_d_validateChecksum)); - } + CHECK( ZSTD_DCtx_setParameter(ress.dctx, ZSTD_d_forceIgnoreChecksum, !prefs->checksumFlag)); ress.srcBufferSize = ZSTD_DStreamInSize(); ress.srcBuffer = malloc(ress.srcBufferSize); diff --git a/tests/fuzzer.c b/tests/fuzzer.c index d85ccef5d..e5c3e6e3b 100644 --- a/tests/fuzzer.c +++ b/tests/fuzzer.c @@ -571,8 +571,8 @@ static int basicUnitTests(U32 const seed, double compressibility) r = ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, cSize); if (!ZSTD_isError(r)) goto _output_error; if (ZSTD_getErrorCode(r) != ZSTD_error_checksum_wrong) goto _output_error; - - CHECK_Z(ZSTD_DCtx_setForceIgnoreChecksum(dctx, ZSTD_d_ignoreChecksum)); + + CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_forceIgnoreChecksum, ZSTD_d_ignoreChecksum)); r = ZSTD_decompressDCtx(dctx, decodedBuffer, CNBuffSize, compressedBuffer, cSize-1); if (!ZSTD_isError(r)) goto _output_error; /* wrong checksum size should still throw error */ r = ZSTD_decompressDCtx(dctx, decodedBuffer, CNBuffSize, compressedBuffer, cSize); diff --git a/tests/playTests.sh b/tests/playTests.sh index affd64eb0..18f6f3fb9 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -261,11 +261,12 @@ zstd tmp -c --compress-literals -19 | zstd -t zstd -b --fast=1 -i0e1 tmp --compress-literals zstd -b --fast=1 -i0e1 tmp --no-compress-literals println "test: --no-check for decompression" -zstd -f tmp -o tmp.zst --check -zstd -f tmp -o tmp1.zst --no-check -printf '\xDE\xAD\xBE\xEF' | dd of=tmp.zst bs=1 seek=$(($(wc -c <"tmp.zst") - 4)) count=4 conv=notrunc # corrupt checksum in tmp +zstd -f tmp -o tmp_corrupt.zst --check +zstd -f tmp -o tmp.zst --no-check +printf '\xDE\xAD\xBE\xEF' | dd of=tmp_corrupt.zst bs=1 seek=$(($(wc -c < "tmp_corrupt.zst") - 4)) count=4 conv=notrunc # corrupt checksum in tmp +zstd -d -f tmp_corrupt.zst --no-check +zstd -d -f tmp_corrupt.zst --check --no-check # final flag overrides zstd -d -f tmp.zst --no-check -zstd -d -f tmp1.zst --no-check println "\n===> zstdgrep tests" ln -sf "$ZSTD_BIN" zstdcat