diff --git a/tests/client/torture_session.c b/tests/client/torture_session.c index d44c888b..f25406d7 100644 --- a/tests/client/torture_session.c +++ b/tests/client/torture_session.c @@ -328,7 +328,7 @@ static void torture_freed_channel_poll_timeout(void **state) struct torture_state *s = *state; ssh_session session = s->ssh.session; ssh_channel channel; - + bool channel_freed = false; char request[256]; char buff[256] = {0}; int rc; @@ -351,10 +351,19 @@ static void torture_freed_channel_poll_timeout(void **state) } while(rc > 0); assert_ssh_return_code(session, rc); + /* when either of these conditions is met the call to ssh_channel_free will + * actually free the channel so calling poll on that channel will be + * use-after-free */ + if ((channel->flags & SSH_CHANNEL_FLAG_CLOSED_REMOTE) || + (channel->flags & SSH_CHANNEL_FLAG_NOT_BOUND)) { + channel_freed = true; + } ssh_channel_free(channel); - rc = ssh_channel_poll_timeout(channel, 500, 0); - assert_int_equal(rc, SSH_ERROR); + if (!channel_freed) { + rc = ssh_channel_poll_timeout(channel, 500, 0); + assert_int_equal(rc, SSH_ERROR); + } } /* Ensure that calling 'ssh_channel_read_nonblocking' on a freed channel does @@ -395,7 +404,7 @@ static void torture_freed_channel_get_exit_status(void **state) struct torture_state *s = *state; ssh_session session = s->ssh.session; ssh_channel channel; - + bool channel_freed = false; char request[256]; char buff[256] = {0}; int rc; @@ -418,10 +427,19 @@ static void torture_freed_channel_get_exit_status(void **state) } while(rc > 0); assert_ssh_return_code(session, rc); + /* when either of these conditions is met the call to ssh_channel_free will + * actually free the channel so calling poll on that channel will be + * use-after-free */ + if ((channel->flags & SSH_CHANNEL_FLAG_CLOSED_REMOTE) || + (channel->flags & SSH_CHANNEL_FLAG_NOT_BOUND)) { + channel_freed = true; + } ssh_channel_free(channel); - rc = ssh_channel_get_exit_status(channel); - assert_ssh_return_code_equal(session, rc, SSH_ERROR); + if (!channel_freed) { + rc = ssh_channel_get_exit_status(channel); + assert_ssh_return_code_equal(session, rc, SSH_ERROR); + } } int torture_run_tests(void) {