diff --git a/src/channels.c b/src/channels.c index 770bd57a..72f44102 100644 --- a/src/channels.c +++ b/src/channels.c @@ -2925,7 +2925,6 @@ error: return rc; } - /** * @brief Read data from a channel into a buffer. * @@ -2939,8 +2938,8 @@ error: * * @param is_stderr A boolean value to mark reading from the stderr stream. * - * @return The number of bytes read, 0 on end of file or SSH_ERROR - * on error. + * @return The number of bytes read, 0 on end of file, SSH_AGAIN on + * timeout and SSH_ERROR on error. * @deprecated Please use ssh_channel_read instead * @warning This function doesn't work in nonblocking/timeout mode * @see ssh_channel_read @@ -3032,8 +3031,6 @@ static int ssh_channel_read_termination(void *s) return 0; } -/* TODO: FIXME Fix the blocking behaviours */ - /** * @brief Reads data from a channel. * @@ -3045,9 +3042,8 @@ static int ssh_channel_read_termination(void *s) * * @param[in] is_stderr A boolean value to mark reading from the stderr flow. * - * @return The number of bytes read, 0 on end of file or SSH_ERROR - * on error. In nonblocking mode it can return 0 if no data - * is available or SSH_AGAIN. + * @return The number of bytes read, 0 on end of file, SSH_AGAIN on + * timeout and SSH_ERROR on error. * * @warning This function may return less than count bytes of data, and won't * block until count bytes have been read. @@ -3075,9 +3071,8 @@ int ssh_channel_read(ssh_channel channel, void *dest, uint32_t count, int is_std * @param[in] timeout_ms A timeout in milliseconds. A value of -1 means * infinite timeout. * - * @return The number of bytes read, 0 on end of file or SSH_ERROR - * on error. In nonblocking mode it Can return 0 if no data - * is available or SSH_AGAIN. + * @return The number of bytes read, 0 on end of file, SSH_AGAIN on + * timeout, SSH_ERROR on error. * * @warning This function may return less than count bytes of data, and won't * block until count bytes have been read. @@ -3133,8 +3128,8 @@ int ssh_channel_read_timeout(ssh_channel channel, timeout_ms, ssh_channel_read_termination, &ctx); - if (rc == SSH_ERROR){ - return rc; + if (rc == SSH_ERROR || rc == SSH_AGAIN) { + return rc; } /* @@ -3188,8 +3183,8 @@ int ssh_channel_read_timeout(ssh_channel channel, * * @param[in] is_stderr A boolean to select the stderr stream. * - * @return The number of bytes read (0 if nothing is available), - * SSH_ERROR on error, and SSH_EOF if the channel is EOF. + * @return The number of bytes read, SSH_AGAIN if nothing is + * available, SSH_ERROR on error, and SSH_EOF if the channel is EOF. * * @see ssh_channel_is_eof() */ diff --git a/src/scp.c b/src/scp.c index 40f45dbc..a1e3687f 100644 --- a/src/scp.c +++ b/src/scp.c @@ -266,7 +266,7 @@ int ssh_scp_close(ssh_scp scp) */ while (!ssh_channel_is_eof(scp->channel)) { rc = ssh_channel_read(scp->channel, buffer, sizeof(buffer), 0); - if (rc == SSH_ERROR || rc == 0) { + if (rc == SSH_ERROR || rc == SSH_AGAIN || rc == 0) { break; } } @@ -603,6 +603,12 @@ int ssh_scp_response(ssh_scp scp, char **response) rc = ssh_channel_read(scp->channel, &code, 1, 0); if (rc == SSH_ERROR) { + scp->state = SSH_SCP_ERROR; + return SSH_ERROR; + } + if (rc == SSH_AGAIN) { + ssh_set_error(scp->session, SSH_FATAL, "SCP: ssh_channel_read timeout"); + scp->state = SSH_SCP_ERROR; return SSH_ERROR; } @@ -760,6 +766,14 @@ int ssh_scp_read_string(ssh_scp scp, char *buffer, size_t len) break; } + if (err == SSH_AGAIN) { + ssh_set_error(scp->session, + SSH_FATAL, + "SCP: ssh_channel_read timeout"); + err = SSH_ERROR; + break; + } + read++; if (buffer[read - 1] == '\n') { break; @@ -1027,12 +1041,16 @@ int ssh_scp_read(ssh_scp scp, void *buffer, size_t size) } rc = ssh_channel_read(scp->channel, buffer, size, 0); - if (rc != SSH_ERROR) { - scp->processed += rc; - } else { + if (rc == SSH_ERROR) { scp->state = SSH_SCP_ERROR; return SSH_ERROR; } + if (rc == SSH_AGAIN) { + ssh_set_error(scp->session, SSH_FATAL, "SCP: ssh_channel_read timeout"); + scp->state = SSH_SCP_ERROR; + return SSH_ERROR; + } + scp->processed += rc; /* Check if we arrived at end of file */ if (scp->processed == scp->filelen) { diff --git a/src/sftp_common.c b/src/sftp_common.c index 005fa9da..b340f9a1 100644 --- a/src/sftp_common.c +++ b/src/sftp_common.c @@ -87,6 +87,12 @@ sftp_packet sftp_packet_read(sftp_session sftp) "Received EOF while reading sftp packet size"); sftp_set_error(sftp, SSH_FX_EOF); goto error; + } else { + ssh_set_error(sftp->session, + SSH_FATAL, + "Timeout while reading sftp packet size"); + sftp_set_error(sftp, SSH_FX_FAILURE); + goto error; } } else { nread += s; @@ -112,6 +118,12 @@ sftp_packet sftp_packet_read(sftp_session sftp) "Received EOF while reading sftp packet type"); sftp_set_error(sftp, SSH_FX_EOF); goto error; + } else { + ssh_set_error(sftp->session, + SSH_FATAL, + "Timeout while reading sftp packet type"); + sftp_set_error(sftp, SSH_FX_FAILURE); + goto error; } } } while (nread < 1); @@ -147,6 +159,12 @@ sftp_packet sftp_packet_read(sftp_session sftp) "Received EOF while reading sftp packet"); sftp_set_error(sftp, SSH_FX_EOF); goto error; + } else { + ssh_set_error(sftp->session, + SSH_FATAL, + "Timeout while reading sftp packet"); + sftp_set_error(sftp, SSH_FX_FAILURE); + goto error; } } } @@ -697,7 +715,8 @@ static void request_queue_free(sftp_request_queue queue) SAFE_FREE(queue); } -static int sftp_enqueue(sftp_session sftp, sftp_message msg) +static int +sftp_enqueue(sftp_session sftp, sftp_message msg) { sftp_request_queue queue = NULL; sftp_request_queue ptr; diff --git a/tests/benchmarks/bench_raw.c b/tests/benchmarks/bench_raw.c index 0fd1446d..05a6ceb0 100644 --- a/tests/benchmarks/bench_raw.c +++ b/tests/benchmarks/bench_raw.c @@ -135,8 +135,13 @@ int benchmarks_raw_up (ssh_session session, struct argument_s *args, snprintf(cmd,sizeof(cmd),"%s /tmp/eater.py", PYTHON_PATH); if(ssh_channel_request_exec(channel,cmd)==SSH_ERROR) goto error; - if((err=ssh_channel_read(channel,buffer,sizeof(buffer)-1,0))==SSH_ERROR) - goto error; + err = ssh_channel_read(channel, buffer, sizeof(buffer) - 1, 0); + if (err == SSH_ERROR) + goto error; + if (err == SSH_AGAIN) { + fprintf(stderr, "ssh_channel_read timeout"); + goto error; + } buffer[err]=0; if(!strstr(buffer,"go")){ fprintf(stderr,"parse error : %s\n",buffer); @@ -160,9 +165,13 @@ int benchmarks_raw_up (ssh_session session, struct argument_s *args, if(args->verbose>0) fprintf(stdout,"Finished upload, now waiting the ack\n"); - - if((err=ssh_channel_read(channel,buffer,5,0))==SSH_ERROR) + err = ssh_channel_read(channel, buffer, 5, 0); + if (err == SSH_ERROR) goto error; + if (err == SSH_AGAIN) { + fprintf(stderr, "ssh_channel_read timeout"); + goto error; + } buffer[err]=0; if(!strstr(buffer,"done")){ fprintf(stderr,"parse error : %s\n",buffer); @@ -272,8 +281,12 @@ int benchmarks_raw_down (ssh_session session, struct argument_s *args, if(toread > args->chunksize) toread = args->chunksize; r=ssh_channel_read(channel,buffer,toread,0); - if(r == SSH_ERROR) - goto error; + if (r == SSH_ERROR) + goto error; + if (r == SSH_AGAIN) { + fprintf(stderr, "ssh_channel_read timeout"); + goto error; + } total += r; } diff --git a/tests/client/torture_sftp_packet_read.c b/tests/client/torture_sftp_packet_read.c index 6cac9a41..4eaba301 100644 --- a/tests/client/torture_sftp_packet_read.c +++ b/tests/client/torture_sftp_packet_read.c @@ -7,31 +7,35 @@ #include "config.h" -#include "torture.h" #include "sftp.c" +#include "torture.h" -#include -#include -#include #include +#include +#include +#include -static int sshd_setup(void **state) +static int +sshd_setup(void **state) { torture_setup_sshd_server(state, false); return 0; } -static int sshd_teardown(void **state) { +static int +sshd_teardown(void **state) +{ torture_teardown_sshd_server(state); return 0; } -static int session_setup(void **state) +static int +session_setup(void **state) { struct torture_state *s = *state; - struct passwd *pwd; + struct passwd *pwd = NULL; int rc; pwd = getpwnam("bob"); @@ -53,7 +57,8 @@ static int session_setup(void **state) return 0; } -static int session_teardown(void **state) +static int +session_teardown(void **state) { struct torture_state *s = *state; @@ -65,31 +70,42 @@ static int session_teardown(void **state) return 0; } -static void torture_sftp_packet_read(void **state) { +static void +torture_sftp_packet_read(void **state) +{ struct torture_state *s = *state; struct torture_sftp *t = s->ssh.tsftp; + sftp_packet packet = NULL; int fds[2]; int rc; - // creating blocking fd is the default pipe behaviour + /* creating blocking fd is the default pipe behaviour */ rc = pipe(fds); - assert_true(rc == 0); + assert_return_code(rc, errno); + t->ssh->opts.timeout = 1; ssh_socket_set_fd(t->ssh->socket, fds[0]); - rc = sftp_packet_read(t->sftp); - assert_true(rc == SSH_AGAIN); + /* + * Making sure that the sftp_packet_read function times out and returns + * NULL. + */ + packet = sftp_packet_read(t->sftp); + assert_null(packet); close(fds[0]); close(fds[1]); } -int torture_run_tests(void) { +int +torture_run_tests(void) +{ int rc; struct CMUnitTest tests[] = { - cmocka_unit_test_setup_teardown(torture_sftp_packet_read, session_setup, - session_teardown), + cmocka_unit_test_setup_teardown(torture_sftp_packet_read, + session_setup, + session_teardown), }; ssh_init();