From 92ec2ea62f34870d0f9900af68eb816845a3494c Mon Sep 17 00:00:00 2001 From: Sean Purcell Date: Tue, 31 Jan 2017 15:57:18 -0800 Subject: [PATCH] More const's and readability improvements --- contrib/educational_decoder/harness.c | 2 +- contrib/educational_decoder/zstd_decompress.c | 860 ++++++++---------- contrib/educational_decoder/zstd_decompress.h | 12 +- 3 files changed, 410 insertions(+), 464 deletions(-) diff --git a/contrib/educational_decoder/harness.c b/contrib/educational_decoder/harness.c index 42424d4bd..107a16a22 100644 --- a/contrib/educational_decoder/harness.c +++ b/contrib/educational_decoder/harness.c @@ -88,7 +88,7 @@ int main(int argc, char **argv) { decompressed_size = MAX_COMPRESSION_RATIO * input_size; fprintf(stderr, "WARNING: Compressed data does contain decompressed " "size, going to assume the compression ratio is at " - "most %d (decompressed size of at most %lld\n", + "most %d (decompressed size of at most %zu)\n", MAX_COMPRESSION_RATIO, decompressed_size); } output = malloc(decompressed_size); diff --git a/contrib/educational_decoder/zstd_decompress.c b/contrib/educational_decoder/zstd_decompress.c index 90d4a5229..3c1c56730 100644 --- a/contrib/educational_decoder/zstd_decompress.c +++ b/contrib/educational_decoder/zstd_decompress.c @@ -17,17 +17,17 @@ /// Zstandard decompression functions. /// `dst` must point to a space at least as large as the reconstructed output. 
-size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, - size_t src_len); +size_t ZSTD_decompress(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len); /// If `dict != NULL` and `dict_len >= 8`, does the same thing as /// `ZSTD_decompress` but uses the provided dict -size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, - size_t src_len, const void *dict, - size_t dict_len); +size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len, + const void *const dict, const size_t dict_len); /// Get the decompressed size of an input stream so memory can be allocated in /// advance -size_t ZSTD_get_decompressed_size(const void *src, size_t src_len); +size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len); /******* UTILITY MACROS AND TYPES *********************************************/ // Max block size decompressed size is 128 KB and literal blocks must be smaller @@ -67,23 +67,21 @@ typedef int64_t i64; /*** BITSTREAM OPERATIONS *************/ /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits -static inline u64 read_bits_LE(const u8 *src, int num, size_t offset); +static inline u64 read_bits_LE(const u8 *src, const int num, + const size_t offset); /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so /// it updates `offset` to `offset - bits`, and then reads `bits` bits from /// `src + offset`. If the offset becomes negative, the extra bits at the /// bottom are filled in with `0` bits instead of reading from before `src`. 
-static inline u64 STREAM_read_bits(const u8 *src, int bits, i64 *offset); +static inline u64 STREAM_read_bits(const u8 *src, const int bits, + i64 *const offset); /*** END BITSTREAM OPERATIONS *********/ /*** BIT COUNTING OPERATIONS **********/ -/// Returns `x`, where `2^x` is the smallest power of 2 greater than or equal to -/// `num`, or `-1` if `num > 2^63` -static inline int log2sup(u64 num); - /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to /// `num`, or `-1` if `num == 0`. -static inline int log2inf(u64 num); +static inline int log2inf(const u64 num); /*** END BIT COUNTING OPERATIONS ******/ /*** HUFFMAN PRIMITIVES ***************/ @@ -101,36 +99,41 @@ typedef struct { } HUF_dtable; /// Decode a single symbol and read in enough bits to refresh the state -static inline u8 HUF_decode_symbol(HUF_dtable *dtable, u16 *state, - const u8 *src, i64 *offset); +static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); /// Read in a full state's worth of bits to initialize it -static inline void HUF_init_state(HUF_dtable *dtable, u16 *state, const u8 *src, - i64 *offset); - -/// Initialize a Huffman decoding table using the table of bit counts provided -static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs); -/// Initialize a Huffman decoding table using the table of weights provided -/// Weights follow the definition provided in the Zstandard specification -static void HUF_init_dtable_usingweights(HUF_dtable *table, u8 *weights, - int num_symbs); +static inline void HUF_init_state(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); /// Decompresses a single Huffman stream, returns the number of bytes decoded. /// `src_len` must be the exact length of the Huffman-coded block. 
-static size_t HUF_decompress_1stream(HUF_dtable *table, u8 *dst, size_t dst_len, - const u8 *src, size_t src_len); +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst, + const size_t dst_len, const u8 *src, + size_t src_len); /// Same as previous but decodes 4 streams, formatted as in the Zstandard /// specification. /// `src_len` must be the exact length of the Huffman-coded block. -static size_t HUF_decompress_4stream(HUF_dtable *dtable, u8 *dst, - size_t dst_len, const u8 *src, - size_t src_len); +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst, + const size_t dst_len, const u8 *const src, + const size_t src_len); + +/// Initialize a Huffman decoding table using the table of bit counts provided +static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, + const int num_symbs); +/// Initialize a Huffman decoding table using the table of weights provided +/// Weights follow the definition provided in the Zstandard specification +static void HUF_init_dtable_usingweights(HUF_dtable *const table, + const u8 *const weights, + const int num_symbs); /// Free the malloc'ed parts of a decoding table -static void HUF_free_dtable(HUF_dtable *dtable); +static void HUF_free_dtable(HUF_dtable *const dtable); /// Deep copy a decoding table, so that it can be used and free'd without /// impacting the source table. 
-static void HUF_copy_dtable(HUF_dtable *dst, const HUF_dtable *src); +static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src); /*** END HUFFMAN PRIMITIVES ***********/ /*** FSE PRIMITIVES *******************/ @@ -151,46 +154,53 @@ typedef struct { } FSE_dtable; /// Return the symbol for the current state -static inline u8 FSE_peek_symbol(FSE_dtable *dtable, u16 state); +static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, + const u16 state); /// Read the number of bits necessary to update state, update, and shift offset /// back to reflect the bits read -static inline void FSE_update_state(FSE_dtable *dtable, u16 *state, - const u8 *src, i64 *offset); +static inline void FSE_update_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); /// Combine peek and update: decode a symbol and update the state -static inline u8 FSE_decode_symbol(FSE_dtable *dtable, u16 *state, - const u8 *src, i64 *offset); +static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); /// Read bits from the stream to initialize the state and shift offset back -static inline void FSE_init_state(FSE_dtable *dtable, u16 *state, const u8 *src, - i64 *offset); +static inline void FSE_init_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights) /// using an FSE decoding table. `src_len` must be the exact length of the /// block. -static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, - size_t dst_len, const u8 *src, - size_t src_len); +static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, + u8 *dst, const size_t dst_len, + const u8 *const src, + const size_t src_len); /// Initialize a decoding table using normalized frequencies. 
-static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, - int num_symbs, int accuracy_log); +static void FSE_init_dtable(FSE_dtable *const dtable, + const i16 *const norm_freqs, const int num_symbs, + const int accuracy_log); /// Decode an FSE header as defined in the Zstandard format specification and /// use the decoded frequencies to initialize a decoding table. -static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, - size_t src_len, int max_accuracy_log); +static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, + const size_t src_len, + const int max_accuracy_log); /// Initialize an FSE table that will always return the same symbol and consume /// 0 bits per symbol, to be used for RLE mode in sequence commands -static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb); +static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb); /// Free the malloc'ed parts of a decoding table -static void FSE_free_dtable(FSE_dtable *dtable); +static void FSE_free_dtable(FSE_dtable *const dtable); /// Deep copy a decoding table, so that it can be used and free'd without /// impacting the source table. -static void FSE_copy_dtable(FSE_dtable *dst, const FSE_dtable *src); +static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src); /*** END FSE PRIMITIVES ***************/ /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/ @@ -291,47 +301,46 @@ typedef struct { /// Accepts a dict argument, which may be NULL indicating no dictionary. 
/// See /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation -static void decode_frame(io_streams_t *streams, dictionary_t *dict); +static void decode_frame(io_streams_t *const streams, + const dictionary_t *const dict); // Decode data in a compressed block -static void decompress_block(io_streams_t *streams, frame_context_t *ctx, - size_t block_len); +static void decompress_block(io_streams_t *const streams, + frame_context_t *const ctx, + const size_t block_len); // Decode the literals section of a block -static size_t decode_literals(io_streams_t *streams, frame_context_t *ctx, - u8 **literals); +static size_t decode_literals(io_streams_t *const streams, + frame_context_t *const ctx, u8 **const literals); // Decode the sequences part of a block -static size_t decode_sequences(frame_context_t *ctx, const u8 *src, - size_t src_len, sequence_command_t **sequences); +static size_t decode_sequences(frame_context_t *const ctx, const u8 *const src, + const size_t src_len, + sequence_command_t **const sequences); // Execute the decoded sequences on the literals block -static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, - sequence_command_t *sequences, - size_t num_sequences, const u8 *literals, - size_t literals_len); +static void execute_sequences(io_streams_t *const streams, + frame_context_t *const ctx, + const sequence_command_t *const sequences, + const size_t num_sequences, + const u8 *literals, + size_t literals_len); // Parse a provided dictionary blob for use in decompression -static void parse_dictionary(dictionary_t *dict, const u8 *src, size_t src_len); -static void free_dictionary(dictionary_t *dict); +static void parse_dictionary(dictionary_t *const dict, const u8 *const src, + const size_t src_len); +static void free_dictionary(dictionary_t *const dict); /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ -size_t ZSTD_decompress(void *dst, size_t dst_len, 
const void *src, - size_t src_len) { +size_t ZSTD_decompress(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len) { return ZSTD_decompress_with_dict(dst, dst_len, src, src_len, NULL, 0); } -size_t ZSTD_decompress_usingDict(void *_ctx, void *dst, size_t dst_len, - const void *src, size_t src_len, - const void *dict, size_t dict_len) { - // _ctx needed to match ZSTD lib signature - return ZSTD_decompress_with_dict(dst, dst_len, src, src_len, dict, - dict_len); -} - -size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, - size_t src_len, const void *dict, - size_t dict_len) { +size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len, + const void *const dict, + const size_t dict_len) { dictionary_t parsed_dict; memset(&parsed_dict, 0, sizeof(dictionary_t)); // dict_len < 8 is not a valid dictionary @@ -351,21 +360,26 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, /******* FRAME DECODING ******************************************************/ -static void decode_data_frame(io_streams_t *streams, dictionary_t *dict); -static void init_frame_context(io_streams_t *streams, frame_context_t *context, - dictionary_t *dict); -static void free_frame_context(frame_context_t *context); -static void parse_frame_header(frame_header_t *header, const u8 *src, - size_t src_len); -static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict); +static void decode_data_frame(io_streams_t *const streams, + const dictionary_t *const dict); +static void init_frame_context(io_streams_t *const streams, + frame_context_t *const context, + const dictionary_t *const dict); +static void free_frame_context(frame_context_t *const context); +static void parse_frame_header(frame_header_t *const header, + const u8 *const src, const size_t src_len); +static void frame_context_apply_dict(frame_context_t *const ctx, + const dictionary_t 
*const dict); -static void decompress_data(io_streams_t *streams, frame_context_t *ctx); +static void decompress_data(io_streams_t *const streams, + frame_context_t *const ctx); -static void decode_frame(io_streams_t *streams, dictionary_t *dict) { +static void decode_frame(io_streams_t *const streams, + const dictionary_t *const dict) { if (streams->src_len < 4) { INP_SIZE(); } - u32 magic_number = read_bits_LE(streams->src, 32, 0); + const u32 magic_number = read_bits_LE(streams->src, 32, 0); streams->src += 4; streams->src_len -= 4; @@ -374,7 +388,7 @@ static void decode_frame(io_streams_t *streams, dictionary_t *dict) { if (streams->src_len < 4) { INP_SIZE(); } - size_t frame_size = read_bits_LE(streams->src, 32, 32); + const size_t frame_size = read_bits_LE(streams->src, 32, 32); if (streams->src_len < 4 + frame_size) { INP_SIZE(); @@ -396,7 +410,8 @@ static void decode_frame(io_streams_t *streams, dictionary_t *dict) { /// are skippable frames. /// See /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format -static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) { +static void decode_data_frame(io_streams_t *const streams, + const dictionary_t *const dict) { frame_context_t ctx; // Initialize the context that needs to be carried from block to block @@ -414,8 +429,10 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) { /// Takes the information provided in the header and dictionary, and initializes /// the context for this frame -static void init_frame_context(io_streams_t *streams, frame_context_t *context, - dictionary_t *dict) { +static void init_frame_context(io_streams_t *const streams, + frame_context_t *const context, + const dictionary_t *const dict) { + // Most fields in context are correct when initialized to 0 memset(context, 0x00, sizeof(frame_context_t)); // Parse data from the frame header @@ -432,7 +449,7 @@ static void 
init_frame_context(io_streams_t *streams, frame_context_t *context, frame_context_apply_dict(context, dict); } -static void free_frame_context(frame_context_t *context) { +static void free_frame_context(frame_context_t *const context) { HUF_free_dtable(&context->literals_dtable); FSE_free_dtable(&context->ll_dtable); @@ -442,20 +459,20 @@ static void free_frame_context(frame_context_t *context) { memset(context, 0, sizeof(frame_context_t)); } -static void parse_frame_header(frame_header_t *header, const u8 *src, - size_t src_len) { +static void parse_frame_header(frame_header_t *const header, + const u8 *const src, const size_t src_len) { if (src_len < 1) { INP_SIZE(); } - u8 descriptor = read_bits_LE(src, 8, 0); + const u8 descriptor = read_bits_LE(src, 8, 0); // decode frame header descriptor into flags - u8 frame_content_size_flag = descriptor >> 6; - u8 single_segment_flag = (descriptor >> 5) & 1; - u8 reserved_bit = (descriptor >> 3) & 1; - u8 content_checksum_flag = (descriptor >> 2) & 1; - u8 dictionary_id_flag = descriptor & 3; + const u8 frame_content_size_flag = descriptor >> 6; + const u8 single_segment_flag = (descriptor >> 5) & 1; + const u8 reserved_bit = (descriptor >> 3) & 1; + const u8 content_checksum_flag = (descriptor >> 2) & 1; + const u8 dictionary_id_flag = descriptor & 3; if (reserved_bit != 0) { CORRUPTION(); @@ -536,7 +553,8 @@ static void parse_frame_header(frame_header_t *header, const u8 *src, /// A dictionary acts as initializing values for the frame context before /// decompression, so we implement it by applying it's predetermined /// tables and content to the context before beginning decompression -static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { +static void frame_context_apply_dict(frame_context_t *const ctx, + const dictionary_t *const dict) { // If the content pointer is NULL then it must be an empty dict if (!dict || !dict->content) return; @@ -574,8 +592,8 @@ static void 
frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { } /// Decompress the data from a frame block by block -static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { - +static void decompress_data(io_streams_t *const streams, + frame_context_t *const ctx) { int last_block = 0; do { if (streams->src_len < 3) { @@ -583,8 +601,8 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { } // Parse the block header last_block = streams->src[0] & 1; - int block_type = (streams->src[0] >> 1) & 3; - size_t block_len = read_bits_LE(streams->src, 21, 3); + const int block_type = (streams->src[0] >> 1) & 3; + const size_t block_len = read_bits_LE(streams->src, 21, 3); streams->src += 3; streams->src_len -= 3; @@ -656,8 +674,8 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { /******* END FRAME DECODING ***************************************************/ /******* BLOCK DECOMPRESSION **************************************************/ -static void decompress_block(io_streams_t *streams, frame_context_t *ctx, - size_t block_len) { +static void decompress_block(io_streams_t *const streams, frame_context_t *const ctx, + const size_t block_len) { if (streams->src_len < block_len) { INP_SIZE(); } @@ -666,15 +684,15 @@ static void decompress_block(io_streams_t *streams, frame_context_t *ctx, // Part 1: decode the literals block u8 *literals = NULL; - size_t literals_size = decode_literals(streams, ctx, &literals); + const size_t literals_size = decode_literals(streams, ctx, &literals); // Part 2: decode the sequences block if (streams->src > end_of_block) { INP_SIZE(); } - size_t sequences_size = end_of_block - streams->src; + const size_t sequences_size = end_of_block - streams->src; sequence_command_t *sequences = NULL; - size_t num_sequences = + const size_t num_sequences = decode_sequences(ctx, streams->src, sequences_size, &sequences); streams->src += sequences_size; @@ -689,18 +707,22 @@ static 
void decompress_block(io_streams_t *streams, frame_context_t *ctx, /******* END BLOCK DECOMPRESSION **********************************************/ /******* LITERALS DECODING ****************************************************/ -static size_t decode_literals_simple(io_streams_t *streams, u8 **literals, - int block_type, int size_format); -static size_t decode_literals_compressed(io_streams_t *streams, - frame_context_t *ctx, u8 **literals, - int block_type, int size_format); +static size_t decode_literals_simple(io_streams_t *const streams, + u8 **const literals, const int block_type, + const int size_format); +static size_t decode_literals_compressed(io_streams_t *const streams, + frame_context_t *const ctx, + u8 **const literals, + const int block_type, + const int size_format); static size_t decode_huf_table(const u8 *src, size_t src_len, - HUF_dtable *dtable); -static size_t fse_decode_hufweights(const u8 *src, size_t src_len, u8 *weights, - int *num_symbs, size_t compressed_size); + HUF_dtable *const dtable); +static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len, + u8 *const weights, int *const num_symbs, + const size_t compressed_size); -static size_t decode_literals(io_streams_t *streams, frame_context_t *ctx, - u8 **literals) { +static size_t decode_literals(io_streams_t *const streams, + frame_context_t *const ctx, u8 **const literals) { if (streams->src_len < 1) { INP_SIZE(); } @@ -720,8 +742,9 @@ static size_t decode_literals(io_streams_t *streams, frame_context_t *ctx, } /// Decodes literals blocks in raw or RLE form -static size_t decode_literals_simple(io_streams_t *streams, u8 **literals, - int block_type, int size_format) { +static size_t decode_literals_simple(io_streams_t *const streams, + u8 **const literals, const int block_type, + const int size_format) { size_t size; switch (size_format) { // These cases are in the form X0 @@ -787,9 +810,11 @@ static size_t decode_literals_simple(io_streams_t *streams, u8 **literals, 
} /// Decodes Huffman compressed literals -static size_t decode_literals_compressed(io_streams_t *streams, - frame_context_t *ctx, u8 **literals, - int block_type, int size_format) { +static size_t decode_literals_compressed(io_streams_t *const streams, + frame_context_t *const ctx, + u8 **const literals, + const int block_type, + const int size_format) { size_t regenerated_size, compressed_size; // Only size_format=0 has 1 stream, so default to 4 int num_streams = 4; @@ -846,8 +871,8 @@ static size_t decode_literals_compressed(io_streams_t *streams, // Decode provided Huffman table HUF_free_dtable(&ctx->literals_dtable); - size_t size = decode_huf_table(streams->src, compressed_size, - &ctx->literals_dtable); + const size_t size = decode_huf_table(streams->src, compressed_size, + &ctx->literals_dtable); streams->src += size; streams->src_len -= size; compressed_size -= size; @@ -873,14 +898,14 @@ static size_t decode_literals_compressed(io_streams_t *streams, // Decode the Huffman table description static size_t decode_huf_table(const u8 *src, size_t src_len, - HUF_dtable *dtable) { + HUF_dtable *const dtable) { if (src_len < 1) { INP_SIZE(); } const u8 *const osrc = src; - u8 header = src[0]; + const u8 header = src[0]; u8 weights[HUF_MAX_SYMBS]; memset(weights, 0, sizeof(weights)); @@ -892,13 +917,16 @@ static size_t decode_huf_table(const u8 *src, size_t src_len, if (header >= 128) { // Direct representation, read the weights out num_symbs = header - 127; - size_t bytes = (num_symbs + 1) / 2; + const size_t bytes = (num_symbs + 1) / 2; if (bytes > src_len) { INP_SIZE(); } for (int i = 0; i < num_symbs; i++) { + // read_bits_LE isn't applicable here because the weights are order + // reversed within each byte + // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#huffman-tree-header if (i % 2 == 0) { weights[i] = src[i / 2] >> 4; } else { @@ -911,7 +939,7 @@ static size_t decode_huf_table(const u8 *src, size_t src_len, } else { // The 
weights are FSE encoded, decode them before we can construct the // table - size_t size = + const size_t size = fse_decode_hufweights(src, src_len, weights, &num_symbs, header); src += size; src_len -= size; @@ -922,14 +950,16 @@ static size_t decode_huf_table(const u8 *src, size_t src_len, return src - osrc; } -static size_t fse_decode_hufweights(const u8 *src, size_t src_len, u8 *weights, - int *num_symbs, size_t compressed_size) { +static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len, + u8 *const weights, int *const num_symbs, + const size_t compressed_size) { const int MAX_ACCURACY_LOG = 7; FSE_dtable dtable; // Construct the FSE table - size_t read = FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG); + const size_t read = + FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG); if (src_len < compressed_size) { INP_SIZE(); @@ -1001,16 +1031,20 @@ static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { /// Offset decoding is simpler so we just need a maximum code value static const u8 SEQ_MAX_CODES[3] = {35, -1, 52}; -static void decompress_sequences(frame_context_t *ctx, const u8 *src, - size_t src_len, sequence_command_t *sequences, - size_t num_sequences); -static sequence_command_t decode_sequence(sequence_state_t *state, - const u8 *src, i64 *offset); -static size_t decode_seq_table(const u8 *src, size_t src_len, FSE_dtable *table, - seq_part_t type, seq_mode_t mode); +static void decompress_sequences(frame_context_t *const ctx, const u8 *src, + size_t src_len, + sequence_command_t *const sequences, + const size_t num_sequences); +static sequence_command_t decode_sequence(sequence_state_t *const state, + const u8 *const src, + i64 *const offset); +static size_t decode_seq_table(const u8 *src, size_t src_len, + FSE_dtable *const table, const seq_part_t type, + const seq_mode_t mode); -static size_t decode_sequences(frame_context_t *ctx, const u8 *src, - size_t src_len, sequence_command_t **sequences) { +static size_t 
decode_sequences(frame_context_t *const ctx, const u8 *src, + size_t src_len, + sequence_command_t **const sequences) { size_t num_sequences; // Decode the sequence header and allocate space for the output @@ -1050,9 +1084,10 @@ static size_t decode_sequences(frame_context_t *ctx, const u8 *src, } /// Decompress the FSE encoded sequence commands -static void decompress_sequences(frame_context_t *ctx, const u8 *src, - size_t src_len, sequence_command_t *sequences, - size_t num_sequences) { +static void decompress_sequences(frame_context_t *const ctx, const u8 *src, + size_t src_len, + sequence_command_t *const sequences, + const size_t num_sequences) { if (src_len < 1) { INP_SIZE(); } @@ -1064,21 +1099,31 @@ static void decompress_sequences(frame_context_t *ctx, const u8 *src, CORRUPTION(); } - sequence_state_t state; - size_t read; - // Update the tables we have stored in the context - read = decode_seq_table(src, src_len, &ctx->ll_dtable, seq_literal_length, - (compression_modes >> 6) & 3); - src += read; - src_len -= read; - read = decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset, - (compression_modes >> 4) & 3); - src += read; - src_len -= read; - read = decode_seq_table(src, src_len, &ctx->ml_dtable, seq_match_length, - (compression_modes >> 2) & 3); - src += read; - src_len -= read; + { + size_t read; + // Update the tables we have stored in the context + read = decode_seq_table(src, src_len, &ctx->ll_dtable, + seq_literal_length, + (compression_modes >> 6) & 3); + src += read; + src_len -= read; + } + + { + const size_t read = + decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset, + (compression_modes >> 4) & 3); + src += read; + src_len -= read; + } + + { + const size_t read = decode_seq_table(src, src_len, &ctx->ml_dtable, + seq_match_length, + (compression_modes >> 2) & 3); + src += read; + src_len -= read; + } // Check to make sure none of the tables are uninitialized if (!ctx->ll_dtable.symbols || !ctx->of_dtable.symbols || @@ -1086,12 
+1131,13 @@ static void decompress_sequences(frame_context_t *ctx, const u8 *src, CORRUPTION(); } - // Now use the context's tables + sequence_state_t state; + // Copy the context's tables into the local state memcpy(&state.ll_table, &ctx->ll_dtable, sizeof(FSE_dtable)); memcpy(&state.of_table, &ctx->of_dtable, sizeof(FSE_dtable)); memcpy(&state.ml_table, &ctx->ml_dtable, sizeof(FSE_dtable)); - int padding = 8 - log2inf(src[src_len - 1]); + const int padding = 8 - log2inf(src[src_len - 1]); i64 offset = src_len * 8 - padding; FSE_init_state(&state.ll_table, &state.ll_state, src, &offset); @@ -1111,12 +1157,13 @@ static void decompress_sequences(frame_context_t *ctx, const u8 *src, } // Decode a single sequence and update the state -static sequence_command_t decode_sequence(sequence_state_t *state, - const u8 *src, i64 *offset) { +static sequence_command_t decode_sequence(sequence_state_t *const state, + const u8 *const src, + i64 *const offset) { // Decode symbols, but don't update states - u8 of_code = FSE_peek_symbol(&state->of_table, state->of_state); - u8 ll_code = FSE_peek_symbol(&state->ll_table, state->ll_state); - u8 ml_code = FSE_peek_symbol(&state->ml_table, state->ml_state); + const u8 of_code = FSE_peek_symbol(&state->of_table, state->of_state); + const u8 ll_code = FSE_peek_symbol(&state->ll_table, state->ll_state); + const u8 ml_code = FSE_peek_symbol(&state->ml_table, state->ml_state); // Offset doesn't need a max value as it's not decoded using a table if (ll_code > SEQ_MAX_CODES[seq_literal_length] || @@ -1147,9 +1194,9 @@ static sequence_command_t decode_sequence(sequence_state_t *state, } /// Given a sequence part and table mode, decode the FSE distribution -static size_t decode_seq_table(const u8 *src, size_t src_len, FSE_dtable *table, - seq_part_t type, seq_mode_t mode) { - +static size_t decode_seq_table(const u8 *src, size_t src_len, + FSE_dtable *const table, const seq_part_t type, + const seq_mode_t mode) { // Constant arrays indexed by 
seq_part_t const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, SEQ_OFFSET_DEFAULT_DIST, @@ -1178,7 +1225,7 @@ static size_t decode_seq_table(const u8 *src, size_t src_len, FSE_dtable *table, if (src_len < 1) { INP_SIZE(); } - u8 symb = src[0]; + const u8 symb = src[0]; src++; src_len--; FSE_init_dtable_rle(table, symb); @@ -1204,15 +1251,17 @@ static size_t decode_seq_table(const u8 *src, size_t src_len, FSE_dtable *table, /******* END SEQUENCE DECODING ************************************************/ /******* SEQUENCE EXECUTION ***************************************************/ -static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, - sequence_command_t *sequences, - size_t num_sequences, const u8 *literals, - size_t literals_len) { - u64 *offset_hist = ctx->previous_offsets; +static void execute_sequences(io_streams_t *const streams, + frame_context_t *const ctx, + const sequence_command_t *const sequences, + const size_t num_sequences, + const u8 *literals, + size_t literals_len) { + u64 *const offset_hist = ctx->previous_offsets; size_t total_output = ctx->current_total_output; for (size_t i = 0; i < num_sequences; i++) { - sequence_command_t seq = sequences[i]; + const sequence_command_t seq = sequences[i]; if (seq.literal_length > literals_len) { CORRUPTION(); @@ -1312,46 +1361,48 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, /******* END SEQUENCE EXECUTION ***********************************************/ /******* OUTPUT SIZE COUNTING *************************************************/ -size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len); +size_t traverse_frame(const frame_header_t *const header, const u8 *src, + size_t src_len); /// Get the decompressed size of an input stream so memory can be allocated in /// advance. 
/// This is more complex than the implementation in the reference /// implementation, as this API allows for the decompression of multiple /// concatenated frames. -size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) { +size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { const u8 *ip = (const u8 *) src; + size_t ip_len = src_len; size_t dst_size = 0; // Each frame header only gives us the size of its frame, so iterate over all // frames - while (src_len > 0) { - if (src_len < 4) { + while (ip_len > 0) { + if (ip_len < 4) { INP_SIZE(); } - u32 magic_number = read_bits_LE(ip, 32, 0); + const u32 magic_number = read_bits_LE(ip, 32, 0); ip += 4; - src_len -= 4; + ip_len -= 4; if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) { // skippable frame, this has no impact on output size - if (src_len < 4) { + if (ip_len < 4) { INP_SIZE(); } - size_t frame_size = read_bits_LE(ip, 32, 32); + const size_t frame_size = read_bits_LE(ip, 32, 32); - if (src_len < 4 + frame_size) { + if (ip_len < 4 + frame_size) { INP_SIZE(); } // skip over frame ip += 4 + frame_size; - src_len -= 4 + frame_size; + ip_len -= 4 + frame_size; } else if (magic_number == 0xFD2FB528U) { // ZSTD frame frame_header_t header; - parse_frame_header(&header, ip, src_len); + parse_frame_header(&header, ip, ip_len); if (header.frame_content_size == 0 && !header.single_segment_flag) { // Content size not provided, we can't tell @@ -1361,9 +1412,9 @@ size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) { dst_size += header.frame_content_size; // we need to traverse the frame to find when the next one starts - size_t traversed = traverse_frame(&header, ip, src_len); + const size_t traversed = traverse_frame(&header, ip, ip_len); ip += traversed; - src_len -= traversed; + ip_len -= traversed; } else { // not a real frame ERROR("Invalid magic number"); @@ -1375,7 +1426,8 @@ size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) { /// Iterate 
over each block in a frame to find the end of it, to get to the /// start of the next frame -size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) { +size_t traverse_frame(const frame_header_t *const header, const u8 *src, + size_t src_len) { const u8 *const src_beg = src; const u8 *const src_end = src + src_len; src += header->header_size; @@ -1389,8 +1441,8 @@ size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) { } // Parse the block header last_block = src[0] & 1; - int block_type = (src[0] >> 1) & 3; - size_t block_len = read_bits_LE(src, 21, 3); + const int block_type = (src[0] >> 1) & 3; + const size_t block_len = read_bits_LE(src, 21, 3); src += 3; switch (block_type) { @@ -1432,16 +1484,16 @@ size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) { /******* END OUTPUT SIZE COUNTING *********************************************/ /******* DICTIONARY PARSING ***************************************************/ -static void init_raw_content_dict(dictionary_t *dict, const u8 *src, - size_t src_len); +static void init_raw_content_dict(dictionary_t *const dict, const u8 *const src, + const size_t src_len); -static void parse_dictionary(dictionary_t *dict, const u8 *src, +static void parse_dictionary(dictionary_t *const dict, const u8 *src, size_t src_len) { memset(dict, 0, sizeof(dictionary_t)); if (src_len < 8) { INP_SIZE(); } - u32 magic_number = read_bits_LE(src, 32, 0); + const u32 magic_number = read_bits_LE(src, 32, 0); if (magic_number != 0xEC30A437) { // raw content dict init_raw_content_dict(dict, src, src_len); @@ -1454,25 +1506,26 @@ static void parse_dictionary(dictionary_t *dict, const u8 *src, // Parse the provided entropy tables in order { - size_t read = decode_huf_table(src, src_len, &dict->literals_dtable); + const size_t read = + decode_huf_table(src, src_len, &dict->literals_dtable); src += read; src_len -= read; } { - size_t read = decode_seq_table(src, src_len, 
&dict->of_dtable, - seq_offset, seq_fse); + const size_t read = decode_seq_table(src, src_len, &dict->of_dtable, + seq_offset, seq_fse); src += read; src_len -= read; } { - size_t read = decode_seq_table(src, src_len, &dict->ml_dtable, - seq_match_length, seq_fse); + const size_t read = decode_seq_table(src, src_len, &dict->ml_dtable, + seq_match_length, seq_fse); src += read; src_len -= read; } { - size_t read = decode_seq_table(src, src_len, &dict->ll_dtable, - seq_literal_length, seq_fse); + const size_t read = decode_seq_table(src, src_len, &dict->ll_dtable, + seq_literal_length, seq_fse); src += read; src_len -= read; } @@ -1505,8 +1558,8 @@ static void parse_dictionary(dictionary_t *dict, const u8 *src, } /// If parse_dictionary is given a raw content dictionary, it delegates here -static void init_raw_content_dict(dictionary_t *dict, const u8 *src, - size_t src_len) { +static void init_raw_content_dict(dictionary_t *const dict, const u8 *const src, + const size_t src_len) { dict->dictionary_id = 0; // Copy in the content dict->content = malloc(src_len); @@ -1519,7 +1572,7 @@ static void init_raw_content_dict(dictionary_t *dict, const u8 *src, } /// Free an allocated dictionary -static void free_dictionary(dictionary_t *dict) { +static void free_dictionary(dictionary_t *const dict) { HUF_free_dtable(&dict->literals_dtable); FSE_free_dtable(&dict->ll_dtable); FSE_free_dtable(&dict->of_dtable); @@ -1531,179 +1584,50 @@ static void free_dictionary(dictionary_t *dict) { } /******* END DICTIONARY PARSING ***********************************************/ -/******* CIRCULAR BUFFER ******************************************************/ -static void cbuf_init(cbuf_t *buf, size_t size) { - buf->ptr = malloc(size); - - if (!buf->ptr) { - BAD_ALLOC(); - } - - memset(buf->ptr, 0x3f, size); - - buf->size = size; - buf->idx = 0; - buf->last_flush = 0; -} - -static size_t cbuf_write_data(cbuf_t *buf, const u8 *src, size_t src_len) { - if (buf->size == 0 && src_len > 0) { - 
CORRUPTION(); - } - size_t max_len = buf->size - buf->idx; - size_t len = MIN(src_len, max_len); - - memcpy(buf->ptr + buf->idx, src, len); - - buf->idx += len; - - return len; -} - -static size_t cbuf_write_data_full(cbuf_t *buf, const u8 *src, size_t src_len, - u8 *out, size_t out_len) { - size_t written = 0; - size_t flushed = 0; - while (1) { - written += cbuf_write_data(buf, src + written, src_len - written); - if (written == src_len) { - break; - } else { - flushed += cbuf_flush(buf, out + flushed, out_len - flushed); - } - } - - return flushed; -} - -static size_t cbuf_copy_offset(cbuf_t *buf, size_t offset, size_t len) { - if (buf->size == 0 && len > 0) { - CORRUPTION(); - } - if (offset > buf->size) { - CORRUPTION(); - } - size_t max_len = buf->size - buf->idx; - len = MIN(len, max_len); - - size_t read_off = (buf->idx + buf->size - offset) % buf->size; - - for (size_t i = 0; i < len; i++) { - buf->ptr[buf->idx++] = buf->ptr[read_off++]; - if (read_off == buf->size) { - read_off = 0; - } - } - - return len; -} - -static size_t cbuf_copy_offset_full(cbuf_t *buf, size_t offset, size_t len, - u8 *out, size_t out_len) { - size_t written = 0; - size_t flushed = 0; - while (1) { - written += cbuf_copy_offset(buf, offset, len - written); - if (written == len) { - break; - } else { - flushed += cbuf_flush(buf, out + flushed, out_len - flushed); - } - } - - return flushed; -} - -static size_t cbuf_repeat_byte(cbuf_t *buf, u8 byte, size_t len) { - if (buf->size == 0 && len > 0) { - CORRUPTION(); - } - size_t max_len = buf->size - buf->idx; - len = MIN(len, max_len); - - memset(buf->ptr + buf->idx, byte, len); - - return len; -} - -static size_t cbuf_repeat_byte_full(cbuf_t *buf, u8 byte, size_t len, u8 *out, - size_t out_len) { - size_t written = 0; - size_t flushed = 0; - while (1) { - written += cbuf_repeat_byte(buf, byte, len - written); - if (written == len) { - break; - } else { - flushed += cbuf_flush(buf, out + flushed, out_len - flushed); - } - } - - return 
flushed; -} - -static size_t cbuf_flush(cbuf_t *buf, u8 *dst, size_t dst_len) { - if (buf->idx < buf->last_flush) { - CORRUPTION(); - } - - size_t len = buf->idx - buf->last_flush; - - if (dst && len > dst_len) { - OUT_SIZE(); - } - - // allow for NULL buffers to indicate flushing to nowhere - if (dst) { - memcpy(dst, buf->ptr + buf->last_flush, len); - } - - // we could have a 0 size buffer - if (buf->size) { - buf->idx = buf->idx % buf->size; - } - buf->last_flush = buf->idx; - - return len; -} - -static void cbuf_free(cbuf_t *buf) { - free(buf->ptr); - memset(buf, 0, sizeof(cbuf_t)); -} -/******* END CIRCULAR BUFFER **************************************************/ - /******* BITSTREAM OPERATIONS *************************************************/ -static inline u64 read_bits_LE(const u8 *src, int num, size_t offset) { +/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits +static inline u64 read_bits_LE(const u8 *src, const int num, + const size_t offset) { if (num > 64) { return -1; } + // Skip over bytes that aren't in range src += offset / 8; - offset %= 8; + size_t bit_offset = offset % 8; u64 res = 0; int shift = 0; int left = num; while (left > 0) { u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); - res += (((u64)*src++ >> offset) & mask) << shift; - shift += 8 - offset; - left -= 8 - offset; - offset = 0; + // Read the next byte, shift it to account for the offset, and then mask + // out the top part if we don't need all the bits + res += (((u64)*src++ >> bit_offset) & mask) << shift; + shift += 8 - bit_offset; + left -= 8 - bit_offset; + bit_offset = 0; } return res; } -static inline u64 STREAM_read_bits(const u8 *src, int bits, i64 *offset) { +/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so +/// it updates `offset` to `offset - bits`, and then reads `bits` bits from +/// `src + offset`. 
If the offset becomes negative, the extra bits at the +/// bottom are filled in with `0` bits instead of reading from before `src`. +static inline u64 STREAM_read_bits(const u8 *const src, const int bits, + i64 *const offset) { *offset = *offset - bits; size_t actual_off = *offset; + size_t actual_bits = bits; + // Don't actually read bits from before the start of src, so if `*offset < + // 0` fix actual_off and actual_bits to reflect the quantity to read if (*offset < 0) { - bits += *offset; + actual_bits += *offset; actual_off = 0; } - u64 res = read_bits_LE(src, bits, actual_off); + u64 res = read_bits_LE(src, actual_bits, actual_off); if (*offset < 0) { // Fill in the bottom "overflowed" bits with 0's @@ -1714,16 +1638,9 @@ static inline u64 STREAM_read_bits(const u8 *src, int bits, i64 *offset) { /******* END BITSTREAM OPERATIONS *********************************************/ /******* BIT COUNTING OPERATIONS **********************************************/ -static inline int log2sup(u64 num) { - for (int i = 0; i < 64; i++) { - if (((u64)1 << i) >= num) { - return i; - } - } - return -1; -} - -static inline int log2inf(u64 num) { +/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to +/// `num`, or `-1` if `num == 0`. 
+static inline int log2inf(const u64 num) { for (int i = 63; i >= 0; i--) { if (((u64)1 << i) <= num) { return i; @@ -1734,33 +1651,38 @@ static inline int log2inf(u64 num) { /******* END BIT COUNTING OPERATIONS ******************************************/ /******* HUFFMAN PRIMITIVES ***************************************************/ -static inline u8 HUF_decode_symbol(HUF_dtable *dtable, u16 *state, - const u8 *src, i64 *offset) { +static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { // Look up the symbol and number of bits to read const u8 symb = dtable->symbols[*state]; const u8 bits = dtable->num_bits[*state]; const u16 rest = STREAM_read_bits(src, bits, offset); + // Shift `bits` bits out of the state, keeping the low order bits that + // weren't necessary to determine this symbol. Then add in the new bits + // read from the stream. *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); return symb; } -static inline void HUF_init_state(HUF_dtable *dtable, u16 *state, const u8 *src, - i64 *offset) { - // Read in a full dtable->max_bits to initialize the state +static inline void HUF_init_state(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + // Read in a full `dtable->max_bits` bits to initialize the state const u8 bits = dtable->max_bits; *state = STREAM_read_bits(src, bits, offset); } -static size_t HUF_decompress_1stream(HUF_dtable *dtable, u8 *dst, - size_t dst_len, const u8 *src, +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst, + const size_t dst_len, const u8 *src, size_t src_len) { - u8 *const dst_max = dst + dst_len; - u8 *const odst = dst; + const u8 *const dst_max = dst + dst_len; + const u8 *const odst = dst; // To maintain similarity with FSE, start from the end // Find the last 1 bit - int padding = 8 - log2inf(src[src_len - 1]); + const int padding = 8 - log2inf(src[src_len 
- 1]); i64 offset = src_len * 8 - padding; u16 state; @@ -1768,6 +1690,7 @@ static size_t HUF_decompress_1stream(HUF_dtable *dtable, u8 *dst, HUF_init_state(dtable, &state, src, &offset); while (dst < dst_max && offset > -dtable->max_bits) { + // Iterate over the stream, decoding one symbol at a time *dst++ = HUF_decode_symbol(dtable, &state, src, &offset); } // If we stopped before consuming all the input, we didn't have enough space @@ -1775,8 +1698,11 @@ static size_t HUF_decompress_1stream(HUF_dtable *dtable, u8 *dst, OUT_SIZE(); } - // The current state should be the `max_bits` preceding the start as - // everything from `src` onward should be consumed + // When all symbols have been decoded, the final state value shouldn't have + // any data from the stream, so it should have "read" dtable->max_bits from + // before the start of `src` + // Therefore `offset`, the edge to start reading new bits at, should be + // dtable->max_bits before the start of the stream if (offset != -dtable->max_bits) { CORRUPTION(); } @@ -1784,28 +1710,18 @@ static size_t HUF_decompress_1stream(HUF_dtable *dtable, u8 *dst, return dst - odst; } -static size_t HUF_decompress_4stream(HUF_dtable *dtable, u8 *dst, - size_t dst_len, const u8 *src, - size_t src_len) { - // Decode each stream independently for simplicity - // If we wanted to we could decode all 4 at the same time for speed, - // utilizing - // more execution units - - const u8 *src1, *src2, *src3, *src4, *src_end; - u8 *dst1, *dst2, *dst3, *dst4, *dst_end; - - size_t total_out = 0; - +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst, + const size_t dst_len, const u8 *const src, + const size_t src_len) { if (src_len < 6) { INP_SIZE(); } - src1 = src + 6; - src2 = src1 + read_bits_LE(src, 16, 0); - src3 = src2 + read_bits_LE(src, 16, 16); - src4 = src3 + read_bits_LE(src, 16, 32); - src_end = src + src_len; + const u8 *const src1 = src + 6; + const u8 *const src2 = src1 + read_bits_LE(src, 16, 0); + 
const u8 *const src3 = src2 + read_bits_LE(src, 16, 16); + const u8 *const src4 = src3 + read_bits_LE(src, 16, 32); + const u8 *const src_end = src + src_len; // We can't test with all 4 sizes because the 4th size is a function of the // other 3 and the provided length @@ -1813,26 +1729,32 @@ static size_t HUF_decompress_4stream(HUF_dtable *dtable, u8 *dst, INP_SIZE(); } - size_t segment_size = (dst_len + 3) / 4; - dst1 = dst; - dst2 = dst1 + segment_size; - dst3 = dst2 + segment_size; - dst4 = dst3 + segment_size; - dst_end = dst + dst_len; + const size_t segment_size = (dst_len + 3) / 4; + u8 *const dst1 = dst; + u8 *const dst2 = dst1 + segment_size; + u8 *const dst3 = dst2 + segment_size; + u8 *const dst4 = dst3 + segment_size; + u8 *const dst_end = dst + dst_len; - total_out += - HUF_decompress_1stream(dtable, dst1, segment_size, src1, src2 - src1); - total_out += - HUF_decompress_1stream(dtable, dst2, segment_size, src2, src3 - src2); - total_out += - HUF_decompress_1stream(dtable, dst3, segment_size, src3, src4 - src3); + size_t total_out = 0; + + // Decode each stream independently for simplicity + // If we wanted to we could decode all 4 at the same time for speed, + // utilizing more execution units + total_out += HUF_decompress_1stream(dtable, dst1, segment_size, src1, + src2 - src1); + total_out += HUF_decompress_1stream(dtable, dst2, segment_size, src2, + src3 - src2); + total_out += HUF_decompress_1stream(dtable, dst3, segment_size, src3, + src4 - src3); total_out += HUF_decompress_1stream(dtable, dst4, dst_end - dst4, src4, src_end - src4); return total_out; } -static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs) { +static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, + const int num_symbs) { memset(table, 0, sizeof(HUF_dtable)); if (num_symbs > HUF_MAX_SYMBS) { ERROR("Too many symbols for Huffman"); @@ -1852,7 +1774,7 @@ static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs) { 
rank_count[bits[i]]++; } - size_t table_size = 1 << max_bits; + const size_t table_size = 1 << max_bits; table->max_bits = max_bits; table->symbols = malloc(table_size); table->num_bits = malloc(table_size); @@ -1881,6 +1803,9 @@ static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs) { if (bits[i] != 0) { // Allocate a code for this symbol and set its range in the table const u16 code = rank_idx[bits[i]]; + // Since the code doesn't care about the bottom `max_bits - bits[i]` + // bits of state, it gets a range that spans all possible values of + // the lower bits const u16 len = 1 << (max_bits - bits[i]); memset(&table->symbols[code], i, len); rank_idx[bits[i]] += len; @@ -1888,8 +1813,9 @@ static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs) { } } -static void HUF_init_dtable_usingweights(HUF_dtable *table, u8 *weights, - int num_symbs) { +static void HUF_init_dtable_usingweights(HUF_dtable *const table, + const u8 *const weights, + const int num_symbs) { // +1 because the last weight is not transmitted in the header if (num_symbs + 1 > HUF_MAX_SYMBS) { ERROR("Too many symbols for Huffman"); @@ -1903,37 +1829,40 @@ static void HUF_init_dtable_usingweights(HUF_dtable *table, u8 *weights, } // Find the first power of 2 larger than the sum - int max_bits = log2inf(weight_sum) + 1; - u64 left_over = ((u64)1 << max_bits) - weight_sum; + const int max_bits = log2inf(weight_sum) + 1; + const u64 left_over = ((u64)1 << max_bits) - weight_sum; // If the left over isn't a power of 2, the weights are invalid if (left_over & (left_over - 1)) { CORRUPTION(); } - int last_weight = log2inf(left_over) + 1; + // left_over is used to find the last weight as it's not transmitted + // by inverting 2^(weight - 1) we can determine the value of last_weight + const int last_weight = log2inf(left_over) + 1; for (int i = 0; i < num_symbs; i++) { bits[i] = weights[i] > 0 ? 
(max_bits + 1 - weights[i]) : 0; } bits[num_symbs] = - max_bits + 1 - last_weight; // last weight is always non-zero + max_bits + 1 - last_weight; // Last weight is always non-zero HUF_init_dtable(table, bits, num_symbs + 1); } -static void HUF_free_dtable(HUF_dtable *dtable) { +static void HUF_free_dtable(HUF_dtable *const dtable) { free(dtable->symbols); free(dtable->num_bits); memset(dtable, 0, sizeof(HUF_dtable)); } -static void HUF_copy_dtable(HUF_dtable *dst, const HUF_dtable *src) { +static void HUF_copy_dtable(HUF_dtable *const dst, + const HUF_dtable *const src) { if (src->max_bits == 0) { memset(dst, 0, sizeof(HUF_dtable)); return; } - size_t size = (size_t)1 << src->max_bits; + const size_t size = (size_t)1 << src->max_bits; dst->max_bits = src->max_bits; dst->symbols = malloc(size); @@ -1948,46 +1877,56 @@ static void HUF_copy_dtable(HUF_dtable *dst, const HUF_dtable *src) { /******* END HUFFMAN PRIMITIVES ***********************************************/ /******* FSE PRIMITIVES *******************************************************/ -static inline u8 FSE_peek_symbol(FSE_dtable *dtable, u16 state) { +/// Allow a symbol to be decoded without updating state +static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, + const u16 state) { return dtable->symbols[state]; } -static inline void FSE_update_state(FSE_dtable *dtable, u16 *state, - const u8 *src, i64 *offset) { +/// Consumes bits from the input and uses the current state to determine the +/// next state +static inline void FSE_update_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { const u8 bits = dtable->num_bits[*state]; const u16 rest = STREAM_read_bits(src, bits, offset); *state = dtable->new_state_base[*state] + rest; } -// Decodes a single FSE symbol and updates the offset -static inline u8 FSE_decode_symbol(FSE_dtable *dtable, u16 *state, - const u8 *src, i64 *offset) { +/// Decodes a single FSE symbol and updates the offset +static 
inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { const u8 symb = FSE_peek_symbol(dtable, *state); FSE_update_state(dtable, state, src, offset); return symb; } -static inline void FSE_init_state(FSE_dtable *dtable, u16 *state, const u8 *src, - i64 *offset) { +static inline void FSE_init_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + // Read in a full `accuracy_log` bits to initialize the state const u8 bits = dtable->accuracy_log; *state = STREAM_read_bits(src, bits, offset); } -static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, - size_t dst_len, const u8 *src, - size_t src_len) { +static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, + u8 *dst, const size_t dst_len, + const u8 *const src, + const size_t src_len) { if (src_len == 0) { INP_SIZE(); } - u8 *dst_max = dst + dst_len; - u8 *const odst = dst; + const u8 *const dst_max = dst + dst_len; + const u8 *const odst = dst; // Find the last 1 bit - int padding = 8 - log2inf(src[src_len - 1]); + const int padding = 8 - log2inf(src[src_len - 1]); i64 offset = src_len * 8 - padding; + // The end of the stream contains the 2 states, in this order u16 state1, state2; FSE_init_state(dtable, &state1, src, &offset); FSE_init_state(dtable, &state2, src, &offset); @@ -2002,7 +1941,7 @@ static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset); if (offset < 0) { // There's still a symbol to decode in state2 - *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset); + *dst++ = FSE_peek_symbol(dtable, state2); break; } @@ -2012,17 +1951,18 @@ static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset); if (offset < 0) { // There's still a symbol to decode in state1 - *dst++ = FSE_decode_symbol(dtable, &state1, src, 
&offset); + *dst++ = FSE_peek_symbol(dtable, state1); break; } } - // number of symbols read + // Number of symbols read return dst - odst; } -static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, - int num_symbs, int accuracy_log) { +static void FSE_init_dtable(FSE_dtable *const dtable, + const i16 *const norm_freqs, const int num_symbs, + const int accuracy_log) { if (accuracy_log > FSE_MAX_ACCURACY_LOG) { ERROR("FSE accuracy too large"); } @@ -2032,7 +1972,7 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, dtable->accuracy_log = accuracy_log; - size_t size = (size_t)1 << accuracy_log; + const size_t size = (size_t)1 << accuracy_log; dtable->symbols = malloc(size * sizeof(u8)); dtable->num_bits = malloc(size * sizeof(u8)); dtable->new_state_base = malloc(size * sizeof(u16)); @@ -2057,8 +1997,8 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, } // Place the rest in the table - u16 step = (size >> 1) + (size >> 3) + 3; - u16 mask = size - 1; + const u16 step = (size >> 1) + (size >> 3) + 3; + const u16 mask = size - 1; u16 pos = 0; for (int s = 0; s < num_symbs; s++) { if (norm_freqs[s] <= 0) { @@ -2068,6 +2008,7 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, state_desc[s] = norm_freqs[s]; for (int i = 0; i < norm_freqs[s]; i++) { + // Give `norm_freqs[s]` states to symbol s dtable->symbols[pos] = s; do { pos = (pos + step) & mask; @@ -2087,18 +2028,21 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, for (int i = 0; i < size; i++) { u8 symbol = dtable->symbols[i]; u16 next_state_desc = state_desc[symbol]++; - // Fills in the table appropriately next_state_desc increases by symbol + // Fills in the table appropriately, next_state_desc increases by symbol // over time, decreasing number of bits dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc)); - // baseline increases until the bit threshold is passed, at which point + // Baseline 
increases until the bit threshold is passed, at which point // it resets to 0 dtable->new_state_base[i] = ((u16)next_state_desc << dtable->num_bits[i]) - size; } } -static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, - size_t src_len, int max_accuracy_log) { +/// Decode an FSE header as defined in the Zstandard format specification and +/// use the decoded frequencies to initialize a decoding table. +static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, + const size_t src_len, + const int max_accuracy_log) { if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { ERROR("FSE accuracy too large"); } @@ -2106,7 +2050,7 @@ static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, INP_SIZE(); } - int accuracy_log = 5 + read_bits_LE(src, 4, 0); + const int accuracy_log = 5 + read_bits_LE(src, 4, 0); if (accuracy_log > max_accuracy_log) { ERROR("FSE accuracy too large"); } @@ -2116,17 +2060,19 @@ static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, i16 frequencies[FSE_MAX_SYMBS]; int symb = 0; + // Offset of 4 because 4 bits were already read in for accuracy size_t offset = 4; while (remaining > 1 && symb < FSE_MAX_SYMBS) { - int bits = log2sup(remaining + - 1); // the number of possible values we could read + // Log of the number of possible values we could read + int bits = log2inf(remaining) + 1; + u16 val = read_bits_LE(src, bits, offset); offset += bits; - // try to mask out the lower bits to see if it qualifies for the "small + // Try to mask out the lower bits to see if it qualifies for the "small // value" threshold - u16 lower_mask = ((u16)1 << (bits - 1)) - 1; - u16 threshold = ((u16)1 << bits) - 1 - remaining; + const u16 lower_mask = ((u16)1 << (bits - 1)) - 1; + const u16 threshold = ((u16)1 << bits) - 1 - remaining; if ((val & lower_mask) < threshold) { offset--; @@ -2135,8 +2081,8 @@ static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, val = val - threshold; } - i16 proba = (i16)val - 1; - 
// a value of -1 is possible, and has special meaning + const i16 proba = (i16)val - 1; + // A value of -1 is possible, and has special meaning remaining -= proba < 0 ? -proba : proba; frequencies[symb] = proba; @@ -2144,7 +2090,7 @@ static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, // Handle the special probability = 0 case if (proba == 0) { - // read the next two bits to see how many more 0s + // Read the next two bits to see how many more 0s int repeat = read_bits_LE(src, 2, offset); offset += 2; @@ -2172,7 +2118,7 @@ static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, return (offset + 7) / 8; } -static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) { +static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) { dtable->symbols = malloc(sizeof(u8)); dtable->num_bits = malloc(sizeof(u8)); dtable->new_state_base = malloc(sizeof(u16)); @@ -2189,14 +2135,14 @@ static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) { dtable->accuracy_log = 0; } -static void FSE_free_dtable(FSE_dtable *dtable) { +static void FSE_free_dtable(FSE_dtable *const dtable) { free(dtable->symbols); free(dtable->num_bits); free(dtable->new_state_base); memset(dtable, 0, sizeof(FSE_dtable)); } -static void FSE_copy_dtable(FSE_dtable *dst, const FSE_dtable *src) { +static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) { if (src->accuracy_log == 0) { memset(dst, 0, sizeof(FSE_dtable)); return; diff --git a/contrib/educational_decoder/zstd_decompress.h b/contrib/educational_decoder/zstd_decompress.h index 6e1736720..16f4da3eb 100644 --- a/contrib/educational_decoder/zstd_decompress.h +++ b/contrib/educational_decoder/zstd_decompress.h @@ -7,10 +7,10 @@ * of patent rights can be found in the PATENTS file in the same directory. 
*/ -size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, - size_t src_len); -size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, - size_t src_len, const void *dict, - size_t dict_len); -size_t ZSTD_get_decompressed_size(const void *src, size_t src_len); +size_t ZSTD_decompress(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len); +size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len, + const void *const dict, const size_t dict_len); +size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len);