diff --git a/zlibWrapper/zstd_zlibwrapper.c b/zlibWrapper/zstd_zlibwrapper.c index c2efe380e..06667b610 100644 --- a/zlibWrapper/zstd_zlibwrapper.c +++ b/zlibWrapper/zstd_zlibwrapper.c @@ -344,10 +344,10 @@ ZEXTERN int ZEXPORT z_inflateInit_ OF((z_streamp strm, { ZWRAP_DCtx* zwd = ZWRAP_createDCtx(strm); LOG_WRAPPER("- inflateInit\n"); - if (zwd == NULL) return Z_MEM_ERROR; + if (zwd == NULL) { strm->state = NULL; return Z_MEM_ERROR; } zwd->version = zwd->customMem.customAlloc(zwd->customMem.opaque, strlen(version) + 1); - if (zwd->version == NULL) { ZWRAP_freeDCtx(zwd); return Z_MEM_ERROR; } + if (zwd->version == NULL) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return Z_MEM_ERROR; } strcpy(zwd->version, version); zwd->stream_size = stream_size; @@ -372,8 +372,6 @@ ZEXTERN int ZEXPORT z_inflateInit2_ OF((z_streamp strm, int windowBits, } - - ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm, const Bytef *dictionary, uInt dictLength)) @@ -382,9 +380,11 @@ ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm, return inflateSetDictionary(strm, dictionary, dictLength); LOG_WRAPPER("- inflateSetDictionary\n"); - { ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state; - size_t errorCode = ZBUFF_decompressInitDictionary(zwd->zbd, dictionary, dictLength); - if (ZSTD_isError(errorCode)) return Z_MEM_ERROR; + { size_t errorCode; + ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state; + if (strm->state == NULL) return Z_MEM_ERROR; + errorCode = ZBUFF_decompressInitDictionary(zwd->zbd, dictionary, dictLength); + if (ZSTD_isError(errorCode)) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return Z_MEM_ERROR; } if (strm->total_in == ZSTD_frameHeaderSize_min) { size_t dstCapacity = 0; @@ -393,6 +393,7 @@ ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm, LOG_WRAPPER("ZBUFF_decompressContinue3 errorCode=%d srcSize=%d dstCapacity=%d\n", (int)errorCode, (int)srcSize, (int)dstCapacity); if (dstCapacity > 0 || ZSTD_isError(errorCode)) { LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode)); + ZWRAP_freeDCtx(zwd); strm->state = NULL; return Z_MEM_ERROR; } } @@ -410,6 +411,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) if (strm->avail_in > 0) { size_t errorCode, dstCapacity, srcSize; ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state; + if (strm->state == NULL) return Z_MEM_ERROR; LOG_WRAPPER("inflate avail_in=%d avail_out=%d total_in=%d total_out=%d\n", (int)strm->avail_in, (int)strm->avail_out, (int)strm->total_in, (int)strm->total_out); if (strm->total_in < ZWRAP_HEADERSIZE) { @@ -432,7 +434,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) else errorCode = inflateInit_(strm, zwd->version, zwd->stream_size); LOG_WRAPPER("ZLIB inflateInit errorCode=%d\n", (int)errorCode); - if (errorCode != Z_OK) return errorCode; + if (errorCode != Z_OK) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return errorCode; } /* inflate header */ strm->next_in = (unsigned char*)zwd->headerBuf; @@ -440,8 +442,8 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) strm->avail_out = 0; errorCode = inflate(strm, Z_NO_FLUSH); LOG_WRAPPER("ZLIB inflate errorCode=%d strm->avail_in=%d\n", (int)errorCode, (int)strm->avail_in); - if (errorCode != Z_OK) return errorCode; - if (strm->avail_in > 0) return Z_MEM_ERROR; + if (errorCode != Z_OK) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return errorCode; } + if (strm->avail_in > 0) goto error; strm->next_in = strm2.next_in; strm->avail_in = strm2.avail_in; @@ -450,17 +452,17 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) strm->reserved = 0; /* mark as zlib stream */ errorCode = ZWRAP_freeDCtx(zwd); - if (ZSTD_isError(errorCode)) return Z_MEM_ERROR; + if (ZSTD_isError(errorCode)) goto error; if (flush == Z_INFLATE_SYNC) return inflateSync(strm); return inflate(strm, flush); } zwd->zbd = ZBUFF_createDCtx_advanced(zwd->customMem); - if (zwd->zbd == NULL) { ZWRAP_freeDCtx(zwd); return Z_MEM_ERROR; } + if (zwd->zbd == NULL) goto error; errorCode = ZBUFF_decompressInit(zwd->zbd); - if (ZSTD_isError(errorCode)) return Z_MEM_ERROR; + if (ZSTD_isError(errorCode)) goto error; srcSize = ZWRAP_HEADERSIZE; dstCapacity = 0; @@ -468,7 +470,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) LOG_WRAPPER("ZBUFF_decompressContinue1 errorCode=%d srcSize=%d dstCapacity=%d\n", (int)errorCode, (int)srcSize, (int)dstCapacity); if (ZSTD_isError(errorCode)) { LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode)); - return Z_MEM_ERROR; + goto error; } if (strm->avail_in == 0) return Z_OK; } @@ -480,7 +482,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) if (ZSTD_isError(errorCode)) { LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode)); zwd->errorCount++; - return (zwd->errorCount<=1) ? Z_NEED_DICT : Z_MEM_ERROR; + if (zwd->errorCount<=1) return Z_NEED_DICT; else goto error; } strm->next_out += dstCapacity; strm->total_out += dstCapacity; @@ -489,6 +491,11 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush)) strm->next_in += srcSize; strm->avail_in -= srcSize; if (errorCode == 0) return Z_STREAM_END; + return Z_OK; +error: + ZWRAP_freeDCtx(zwd); + strm->state = NULL; + return Z_MEM_ERROR; } return Z_OK; } @@ -503,6 +510,7 @@ ZEXTERN int ZEXPORT z_inflateEnd OF((z_streamp strm)) LOG_WRAPPER("- inflateEnd total_in=%d total_out=%d\n", (int)(strm->total_in), (int)(strm->total_out)); { ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state; size_t const errorCode = ZWRAP_freeDCtx(zwd); + strm->state = NULL; if (ZSTD_isError(errorCode)) return Z_MEM_ERROR; } return ret;