diff --git a/src/gzip.c b/src/gzip.c index fccfd05c..cff15518 100644 --- a/src/gzip.c +++ b/src/gzip.c @@ -24,211 +24,237 @@ #include "config.h" -#include #include +#include #include -#include "libssh/priv.h" #include "libssh/buffer.h" #include "libssh/crypto.h" +#include "libssh/priv.h" #include "libssh/session.h" #ifndef BLOCKSIZE #define BLOCKSIZE 4092 #endif -static z_stream *initcompress(ssh_session session, int level) { - z_stream *stream = NULL; - int status; - - stream = calloc(1, sizeof(z_stream)); - if (stream == NULL) { - return NULL; - } - - status = deflateInit(stream, level); - if (status != Z_OK) { - SAFE_FREE(stream); - ssh_set_error(session, SSH_FATAL, - "status %d initialising zlib deflate", status); - return NULL; - } - - return stream; -} - -static ssh_buffer gzip_compress(ssh_session session, ssh_buffer source, int level) +static z_stream * +initcompress(ssh_session session, int level) { - struct ssh_crypto_struct *crypto = NULL; - z_stream *zout = NULL; - void *in_ptr = ssh_buffer_get(source); - uint32_t in_size = ssh_buffer_get_len(source); - ssh_buffer dest = NULL; - unsigned char out_buf[BLOCKSIZE] = {0}; - uint32_t len; - int status; + z_stream *stream = NULL; + int status; - crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_OUT); - if (crypto == NULL) { - return NULL; - } - zout = crypto->compress_out_ctx; - if (zout == NULL) { - zout = crypto->compress_out_ctx = initcompress(session, level); - if (zout == NULL) { - return NULL; + stream = calloc(1, sizeof(z_stream)); + if (stream == NULL) { + return NULL; } - } - dest = ssh_buffer_new(); - if (dest == NULL) { - return NULL; - } - - zout->next_out = out_buf; - zout->next_in = in_ptr; - zout->avail_in = in_size; - do { - zout->avail_out = BLOCKSIZE; - status = deflate(zout, Z_PARTIAL_FLUSH); + status = deflateInit(stream, level); if (status != Z_OK) { - SSH_BUFFER_FREE(dest); - ssh_set_error(session, SSH_FATAL, - "status %d deflating zlib packet", status); - return NULL; + SAFE_FREE(stream); + ssh_set_error(session, + SSH_FATAL, + "status %d initialising zlib deflate", + status); + return NULL; } - len = BLOCKSIZE - zout->avail_out; - if (ssh_buffer_add_data(dest, out_buf, len) < 0) { - SSH_BUFFER_FREE(dest); - return NULL; - } - zout->next_out = out_buf; - } while (zout->avail_out == 0); - return dest; + return stream; } -int compress_buffer(ssh_session session, ssh_buffer buf) { - ssh_buffer dest = NULL; +static ssh_buffer +gzip_compress(ssh_session session, ssh_buffer source, int level) +{ + struct ssh_crypto_struct *crypto = NULL; + z_stream *zout = NULL; + void *in_ptr = ssh_buffer_get(source); + uint32_t in_size = ssh_buffer_get_len(source); + ssh_buffer dest = NULL; + unsigned char out_buf[BLOCKSIZE] = {0}; + uint32_t len; + int status; - dest = gzip_compress(session, buf, session->opts.compressionlevel); - if (dest == NULL) { - return -1; - } + crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_OUT); + if (crypto == NULL) { + return NULL; + } + zout = crypto->compress_out_ctx; + if (zout == NULL) { + zout = crypto->compress_out_ctx = initcompress(session, level); + if (zout == NULL) { + return NULL; + } + } + + dest = ssh_buffer_new(); + if (dest == NULL) { + return NULL; + } + + zout->next_out = out_buf; + zout->next_in = in_ptr; + zout->avail_in = in_size; + do { + zout->avail_out = BLOCKSIZE; + status = deflate(zout, Z_PARTIAL_FLUSH); + if (status != Z_OK) { + SSH_BUFFER_FREE(dest); + ssh_set_error(session, + SSH_FATAL, + "status %d deflating zlib packet", + status); + return NULL; + } + len = BLOCKSIZE - zout->avail_out; + if (ssh_buffer_add_data(dest, out_buf, len) < 0) { + SSH_BUFFER_FREE(dest); + return NULL; + } + zout->next_out = out_buf; + } while (zout->avail_out == 0); + + return dest; +} + +int +compress_buffer(ssh_session session, ssh_buffer buf) +{ + ssh_buffer dest = NULL; + int rv; + + dest = gzip_compress(session, buf, session->opts.compressionlevel); + if (dest == NULL) { + return -1; + } + + if (ssh_buffer_reinit(buf) < 0) { + SSH_BUFFER_FREE(dest); + return -1; + } + + rv = ssh_buffer_add_data(buf, + ssh_buffer_get(dest), + ssh_buffer_get_len(dest)); + if (rv < 0) { + SSH_BUFFER_FREE(dest); + return -1; + } - if (ssh_buffer_reinit(buf) < 0) { SSH_BUFFER_FREE(dest); - return -1; - } - - if (ssh_buffer_add_data(buf, ssh_buffer_get(dest), ssh_buffer_get_len(dest)) < 0) { - SSH_BUFFER_FREE(dest); - return -1; - } - - SSH_BUFFER_FREE(dest); - return 0; + return 0; } /* decompression */ -static z_stream *initdecompress(ssh_session session) { - z_stream *stream = NULL; - int status; - - stream = calloc(1, sizeof(z_stream)); - if (stream == NULL) { - return NULL; - } - - status = inflateInit(stream); - if (status != Z_OK) { - SAFE_FREE(stream); - ssh_set_error(session, SSH_FATAL, - "Status = %d initiating inflate context!", status); - return NULL; - } - - return stream; -} - -static ssh_buffer gzip_decompress(ssh_session session, ssh_buffer source, size_t maxlen) +static z_stream * +initdecompress(ssh_session session) { - struct ssh_crypto_struct *crypto = NULL; - z_stream *zin = NULL; - void *in_ptr = ssh_buffer_get(source); - uint32_t in_size = ssh_buffer_get_len(source); - unsigned char out_buf[BLOCKSIZE] = {0}; - ssh_buffer dest = NULL; - uint32_t len; - int status; + z_stream *stream = NULL; + int status; - crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_IN); - if (crypto == NULL) { - return NULL; - } + stream = calloc(1, sizeof(z_stream)); + if (stream == NULL) { + return NULL; + } - zin = crypto->compress_in_ctx; - if (zin == NULL) { - zin = crypto->compress_in_ctx = initdecompress(session); + status = inflateInit(stream); + if (status != Z_OK) { + SAFE_FREE(stream); + ssh_set_error(session, + SSH_FATAL, + "Status = %d initiating inflate context!", + status); + return NULL; + } + + return stream; +} + +static ssh_buffer +gzip_decompress(ssh_session session, ssh_buffer source, size_t maxlen) +{ + struct ssh_crypto_struct *crypto = NULL; + z_stream *zin = NULL; + void *in_ptr = ssh_buffer_get(source); + uint32_t in_size = ssh_buffer_get_len(source); + unsigned char out_buf[BLOCKSIZE] = {0}; + ssh_buffer dest = NULL; + uint32_t len; + int status; + + crypto = ssh_packet_get_current_crypto(session, SSH_DIRECTION_IN); + if (crypto == NULL) { + return NULL; + } + + zin = crypto->compress_in_ctx; if (zin == NULL) { - return NULL; - } - } - - dest = ssh_buffer_new(); - if (dest == NULL) { - return NULL; - } - - zin->next_out = out_buf; - zin->next_in = in_ptr; - zin->avail_in = in_size; - - do { - zin->avail_out = BLOCKSIZE; - status = inflate(zin, Z_PARTIAL_FLUSH); - if (status != Z_OK && status != Z_BUF_ERROR) { - ssh_set_error(session, SSH_FATAL, - "status %d inflating zlib packet", status); - SSH_BUFFER_FREE(dest); - return NULL; + zin = crypto->compress_in_ctx = initdecompress(session); + if (zin == NULL) { + return NULL; + } } - len = BLOCKSIZE - zin->avail_out; - if (ssh_buffer_add_data(dest,out_buf,len) < 0) { - SSH_BUFFER_FREE(dest); - return NULL; - } - if (ssh_buffer_get_len(dest) > maxlen){ - /* Size of packet exceeded, avoid a denial of service attack */ - SSH_BUFFER_FREE(dest); - return NULL; + dest = ssh_buffer_new(); + if (dest == NULL) { + return NULL; } + zin->next_out = out_buf; - } while (zin->avail_out == 0); + zin->next_in = in_ptr; + zin->avail_in = in_size; - return dest; + do { + zin->avail_out = BLOCKSIZE; + status = inflate(zin, Z_PARTIAL_FLUSH); + if (status != Z_OK && status != Z_BUF_ERROR) { + ssh_set_error(session, + SSH_FATAL, + "status %d inflating zlib packet", + status); + SSH_BUFFER_FREE(dest); + return NULL; + } + + len = BLOCKSIZE - zin->avail_out; + if (ssh_buffer_add_data(dest, out_buf, len) < 0) { + SSH_BUFFER_FREE(dest); + return NULL; + } + if (ssh_buffer_get_len(dest) > maxlen) { + /* Size of packet exceeded, avoid a denial of service attack */ + SSH_BUFFER_FREE(dest); + return NULL; + } + zin->next_out = out_buf; + } while (zin->avail_out == 0); + + return dest; } -int decompress_buffer(ssh_session session,ssh_buffer buf, size_t maxlen){ - ssh_buffer dest = NULL; +int +decompress_buffer(ssh_session session, ssh_buffer buf, size_t maxlen) +{ + ssh_buffer dest = NULL; + int rv; - dest = gzip_decompress(session,buf, maxlen); - if (dest == NULL) { - return -1; - } + dest = gzip_decompress(session, buf, maxlen); + if (dest == NULL) { + return -1; + } + + if (ssh_buffer_reinit(buf) < 0) { + SSH_BUFFER_FREE(dest); + return -1; + } + + rv = ssh_buffer_add_data(buf, + ssh_buffer_get(dest), + ssh_buffer_get_len(dest)); + if (rv < 0) { + SSH_BUFFER_FREE(dest); + return -1; + } - if (ssh_buffer_reinit(buf) < 0) { SSH_BUFFER_FREE(dest); - return -1; - } - - if (ssh_buffer_add_data(buf, ssh_buffer_get(dest), ssh_buffer_get_len(dest)) < 0) { - SSH_BUFFER_FREE(dest); - return -1; - } - - SSH_BUFFER_FREE(dest); - return 0; + return 0; }