1
0
mirror of https://git.libssh.org/projects/libssh.git synced 2025-08-08 19:02:06 +03:00

packet: Refactor ssh_packet_socket_callback().

Make error checking more readable and add additional NULL checks.
This commit is contained in:
Andreas Schneider
2013-11-09 13:10:41 +01:00
parent 5581645500
commit cda641176d

View File

@@ -141,36 +141,44 @@ static ssh_packet_callback default_packet_handlers[]= {
* @len length of data received. It might not be enough for a complete packet * @len length of data received. It might not be enough for a complete packet
* @returns number of bytes read and processed. * @returns number of bytes read and processed.
*/ */
int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user){ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
ssh_session session=(ssh_session) user; {
ssh_session session= (ssh_session) user;
unsigned int blocksize = (session->current_crypto ? unsigned int blocksize = (session->current_crypto ?
session->current_crypto->in_cipher->blocksize : 8); session->current_crypto->in_cipher->blocksize : 8);
int current_macsize = session->current_crypto ? MACSIZE : 0; int current_macsize = session->current_crypto ? MACSIZE : 0;
unsigned char mac[30] = {0}; unsigned char mac[30] = {0};
char buffer[16] = {0}; char buffer[16] = {0};
const void *packet = NULL; const uint8_t *packet;
int to_be_read; int to_be_read;
int rc; int rc;
uint32_t len, compsize, payloadsize; uint32_t len, compsize, payloadsize;
uint8_t padding; uint8_t padding;
size_t processed=0; /* number of byte processed from the callback */ size_t processed = 0; /* number of byte processed from the callback */
if (data == NULL) { if (data == NULL) {
goto error; goto error;
} }
if (session->session_state == SSH_SESSION_STATE_ERROR) if (session->session_state == SSH_SESSION_STATE_ERROR) {
goto error; goto error;
}
switch(session->packet_state) { switch(session->packet_state) {
case PACKET_STATE_INIT: case PACKET_STATE_INIT:
if(receivedlen < blocksize){ if (receivedlen < blocksize) {
/* We didn't receive enough data to read at least one block size, give up */ /*
* We didn't receive enough data to read at least one
* block size, give up
*/
return 0; return 0;
} }
memset(&session->in_packet, 0, sizeof(PACKET)); memset(&session->in_packet, 0, sizeof(PACKET));
if (session->in_buffer) { if (session->in_buffer) {
if (buffer_reinit(session->in_buffer) < 0) { rc = buffer_reinit(session->in_buffer);
if (rc < 0) {
goto error; goto error;
} }
} else { } else {
@@ -180,29 +188,34 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
} }
} }
memcpy(buffer,data,blocksize); memcpy(buffer, data, blocksize);
processed += blocksize; processed += blocksize;
len = packet_decrypt_len(session, buffer); len = packet_decrypt_len(session, buffer);
if (buffer_add_data(session->in_buffer, buffer, blocksize) < 0) { rc = buffer_add_data(session->in_buffer, buffer, blocksize);
if (rc < 0) {
goto error; goto error;
} }
if(len > MAX_PACKET_LEN) { if (len > MAX_PACKET_LEN) {
ssh_set_error(session, SSH_FATAL, ssh_set_error(session,
"read_packet(): Packet len too high(%u %.4x)", len, len); SSH_FATAL,
"read_packet(): Packet len too high(%u %.4x)",
len, len);
goto error; goto error;
} }
to_be_read = len - blocksize + sizeof(uint32_t); to_be_read = len - blocksize + sizeof(uint32_t);
if (to_be_read < 0) { if (to_be_read < 0) {
/* remote sshd sends invalid sizes? */ /* remote sshd sends invalid sizes? */
ssh_set_error(session, SSH_FATAL, ssh_set_error(session,
"given numbers of bytes left to be read < 0 (%d)!", to_be_read); SSH_FATAL,
"Given numbers of bytes left to be read < 0 (%d)!",
to_be_read);
goto error; goto error;
} }
/* saves the status of the current operations */ /* Saves the status of the current operations */
session->in_packet.len = len; session->in_packet.len = len;
session->packet_state = PACKET_STATE_SIZEREAD; session->packet_state = PACKET_STATE_SIZEREAD;
/* FALL TROUGH */ /* FALL TROUGH */
@@ -211,17 +224,26 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
to_be_read = len - blocksize + sizeof(uint32_t) + current_macsize; to_be_read = len - blocksize + sizeof(uint32_t) + current_macsize;
/* if to_be_read is zero, the whole packet was blocksize bytes. */ /* if to_be_read is zero, the whole packet was blocksize bytes. */
if (to_be_read != 0) { if (to_be_read != 0) {
if(receivedlen - processed < (unsigned int)to_be_read){ if (receivedlen - processed < (unsigned int)to_be_read) {
/* give up, not enough data in buffer */ /* give up, not enough data in buffer */
SSH_LOG(SSH_LOG_PACKET,"packet: partial packet (read len) [len=%d]",len); SSH_LOG(SSH_LOG_PACKET,"packet: partial packet (read len) [len=%d]",len);
return processed; return processed;
} }
packet = ((unsigned char *)data) + processed; packet = ((uint8_t*)data) + processed;
// ssh_socket_read(session->socket,packet,to_be_read-current_macsize); if (packet == NULL) {
goto error;
}
#if 0
ssh_socket_read(session->socket,
packet,
to_be_read - current_macsize);
#endif
if (buffer_add_data(session->in_buffer, packet, rc = buffer_add_data(session->in_buffer,
to_be_read - current_macsize) < 0) { packet,
to_be_read - current_macsize);
if (rc < 0) {
goto error; goto error;
} }
processed += to_be_read - current_macsize; processed += to_be_read - current_macsize;
@@ -229,19 +251,26 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
if (session->current_crypto) { if (session->current_crypto) {
/* /*
* decrypt the rest of the packet (blocksize bytes already * Decrypt the rest of the packet (blocksize bytes already
* have been decrypted) * have been decrypted)
*/ */
if (packet_decrypt(session, rc = packet_decrypt(session,
((uint8_t*)buffer_get_rest(session->in_buffer) + blocksize), ((uint8_t*)buffer_get_rest(session->in_buffer) + blocksize),
buffer_get_rest_len(session->in_buffer) - blocksize) < 0) { buffer_get_rest_len(session->in_buffer) - blocksize);
if (rc < 0) {
ssh_set_error(session, SSH_FATAL, "Decrypt error"); ssh_set_error(session, SSH_FATAL, "Decrypt error");
goto error; goto error;
} }
/* copy the last part from the incoming buffer */
memcpy(mac,(unsigned char *)packet + to_be_read - current_macsize, MACSIZE);
if (packet_hmac_verify(session, session->in_buffer, mac) < 0) { /* copy the last part from the incoming buffer */
packet = packet + to_be_read - current_macsize;
if (packet == NULL) {
goto error;
}
memcpy(mac, packet, MACSIZE);
rc = packet_hmac_verify(session, session->in_buffer, mac);
if (rc < 0) {
ssh_set_error(session, SSH_FATAL, "HMAC error"); ssh_set_error(session, SSH_FATAL, "HMAC error");
goto error; goto error;
} }
@@ -251,13 +280,17 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
/* skip the size field which has been processed before */ /* skip the size field which has been processed before */
buffer_pass_bytes(session->in_buffer, sizeof(uint32_t)); buffer_pass_bytes(session->in_buffer, sizeof(uint32_t));
if (buffer_get_u8(session->in_buffer, &padding) == 0) { rc = buffer_get_u8(session->in_buffer, &padding);
ssh_set_error(session, SSH_FATAL, "Packet too short to read padding"); if (rc == 0) {
ssh_set_error(session,
SSH_FATAL,
"Packet too short to read padding");
goto error; goto error;
} }
if (padding > buffer_get_rest_len(session->in_buffer)) { if (padding > buffer_get_rest_len(session->in_buffer)) {
ssh_set_error(session, SSH_FATAL, ssh_set_error(session,
SSH_FATAL,
"Invalid padding: %d (%d left)", "Invalid padding: %d (%d left)",
padding, padding,
buffer_get_rest_len(session->in_buffer)); buffer_get_rest_len(session->in_buffer));
@@ -269,29 +302,40 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
#ifdef WITH_ZLIB #ifdef WITH_ZLIB
if (session->current_crypto if (session->current_crypto
&& session->current_crypto->do_compress_in && session->current_crypto->do_compress_in
&& buffer_get_rest_len(session->in_buffer)) { && buffer_get_rest_len(session->in_buffer) > 0) {
if (decompress_buffer(session, session->in_buffer,MAX_PACKET_LEN) < 0) { rc = decompress_buffer(session, session->in_buffer,MAX_PACKET_LEN);
if (rc < 0) {
goto error; goto error;
} }
} }
#endif /* WITH_ZLIB */ #endif /* WITH_ZLIB */
payloadsize=buffer_get_rest_len(session->in_buffer); payloadsize = buffer_get_rest_len(session->in_buffer);
session->recv_seq++; session->recv_seq++;
/* We don't want to rewrite a new packet while still executing the packet callbacks */
/*
* We don't want to rewrite a new packet while still executing the
* packet callbacks
*/
session->packet_state = PACKET_STATE_PROCESSING; session->packet_state = PACKET_STATE_PROCESSING;
ssh_packet_parse_type(session); ssh_packet_parse_type(session);
SSH_LOG(SSH_LOG_PACKET, SSH_LOG(SSH_LOG_PACKET,
"packet: read type %hhd [len=%d,padding=%hhd,comp=%d,payload=%d]", "packet: read type %hhd [len=%d,padding=%hhd,comp=%d,payload=%d]",
session->in_packet.type, len, padding, compsize, payloadsize); session->in_packet.type, len, padding, compsize, payloadsize);
/* execute callbacks */
/* Execute callbacks */
ssh_packet_process(session, session->in_packet.type); ssh_packet_process(session, session->in_packet.type);
session->packet_state = PACKET_STATE_INIT; session->packet_state = PACKET_STATE_INIT;
if(processed < receivedlen){ if (processed < receivedlen) {
/* Handle a potential packet left in socket buffer */ /* Handle a potential packet left in socket buffer */
SSH_LOG(SSH_LOG_PACKET,"Processing %" PRIdS " bytes left in socket buffer", SSH_LOG(SSH_LOG_PACKET,
"Processing %" PRIdS " bytes left in socket buffer",
receivedlen-processed); receivedlen-processed);
rc = ssh_packet_socket_callback(((unsigned char *)data) + processed,
receivedlen - processed,user); packet = ((uint8_t*)data) + processed;
if (packet == NULL) {
goto error;
}
rc = ssh_packet_socket_callback(packet, receivedlen - processed,user);
processed += rc; processed += rc;
} }
@@ -301,7 +345,8 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
return 0; return 0;
} }
ssh_set_error(session, SSH_FATAL, ssh_set_error(session,
SSH_FATAL,
"Invalid state into packet_read2(): %d", "Invalid state into packet_read2(): %d",
session->packet_state); session->packet_state);