diff --git a/doc/educational_decoder/zstd_decompress.c b/doc/educational_decoder/zstd_decompress.c index 7c8d8114d..93c346312 100644 --- a/doc/educational_decoder/zstd_decompress.c +++ b/doc/educational_decoder/zstd_decompress.c @@ -28,6 +28,7 @@ size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, /// Get the decompressed size of an input stream so memory can be allocated in /// advance /// Returns -1 if the size can't be determined +/// Assumes decompression of a single frame size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len); /******* UTILITY MACROS AND TYPES *********************************************/ @@ -396,9 +397,9 @@ size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, // Multiple frames can be appended into a single file or stream. A frame is // totally independent, has a defined beginning and end, and a set of // parameters which tells the decoder how to decompress it." - while (IO_istream_len(&in) > 0) { - decode_frame(&out, &in, &parsed_dict); - } + + /* this decoder assumes decompression of a single frame */ + decode_frame(&out, &in, &parsed_dict); free_dictionary(&parsed_dict); @@ -424,30 +425,6 @@ static void decompress_data(frame_context_t *const ctx, ostream_t *const out, static void decode_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict) { const u32 magic_number = IO_read_bits(in, 32); - - // Skippable frame - // - // "Magic_Number - // - // 4 Bytes, little-endian format. Value : 0x184D2A5?, which means any value - // from 0x184D2A50 to 0x184D2A5F. All 16 values are valid to identify a - // skippable frame." - if ((magic_number & ~0xFU) == 0x184D2A50U) { - // "Skippable frames allow the insertion of user-defined data into a - // flow of concatenated frames. Its design is pretty straightforward, - // with the sole objective to allow the decoder to quickly skip over - // user-defined data and continue decoding. - // - // Skippable frames defined in this specification are compatible with - // LZ4 ones." - const size_t frame_size = IO_read_bits(in, 32); - - // skip over frame - IO_advance_input(in, frame_size); - - return; - } - // Zstandard frame // // "Magic_Number @@ -460,8 +437,8 @@ static void decode_frame(ostream_t *const out, istream_t *const in, return; } - // not a real frame - ERROR("Invalid magic number"); + // not a real frame or a skippable frame + ERROR("Tried to decode non-ZSTD frame"); } /// Decode a frame that contains compressed data. Not all frames do as there @@ -1420,28 +1397,17 @@ static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, /******* END SEQUENCE EXECUTION ***********************************************/ /******* OUTPUT SIZE COUNTING *************************************************/ -static void traverse_frame(const frame_header_t *const header, istream_t *const in); - /// Get the decompressed size of an input stream so memory can be allocated in /// advance. -/// This is more complex than the implementation in the reference -/// implementation, as this API allows for the decompression of multiple -/// concatenated frames. +/// This implementation assumes `src` points to a single ZSTD-compressed frame size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { istream_t in = IO_make_istream(src, src_len); - size_t dst_size = 0; - // Each frame header only gives us the size of its frame, so iterate over - // all - // frames - while (IO_istream_len(&in) > 0) { + // get decompressed size from ZSTD frame header + { const u32 magic_number = IO_read_bits(&in, 32); - if ((magic_number & ~0xFU) == 0x184D2A50U) { - // skippable frame, this has no impact on output size - const size_t frame_size = IO_read_bits(&in, 32); - IO_advance_input(&in, frame_size); - } else if (magic_number == 0xFD2FB528U) { + if (magic_number == 0xFD2FB528U) { // ZSTD frame frame_header_t header; parse_frame_header(&header, &in); @@ -1451,54 +1417,13 @@ size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { return -1; } - dst_size += header.frame_content_size; - - // Consume the input from the frame to reach the start of the next - traverse_frame(&header, &in); + return header.frame_content_size; } else { - // not a real frame - ERROR("Invalid magic number"); + // not a real frame or skippable frame + ERROR("ZSTD frame magic number did not match"); } } - - return dst_size; } - -/// Iterate over each block in a frame to find the end of it, to get to the -/// start of the next frame -static void traverse_frame(const frame_header_t *const header, istream_t *const in) { - int last_block = 0; - - do { - // Parse the block header - last_block = IO_read_bits(in, 1); - const int block_type = IO_read_bits(in, 2); - const size_t block_len = IO_read_bits(in, 21); - - switch (block_type) { - case 0: // Raw block, block_len bytes - IO_advance_input(in, block_len); - break; - case 1: // RLE block, 1 byte - IO_advance_input(in, 1); - break; - case 2: // Compressed block, compressed size is block_len - IO_advance_input(in, block_len); - break; - case 3: - // Reserved block type - CORRUPTION(); - break; - default: - IMPOSSIBLE(); - } - } while (!last_block); - - if (header->content_checksum_flag) { - IO_advance_input(in, 4); - } -} - /******* END OUTPUT SIZE COUNTING *********************************************/ /******* DICTIONARY PARSING ***************************************************/ @@ -2355,4 +2280,3 @@ static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); } /******* END FSE PRIMITIVES ***************************************************/ -