diff --git a/tests/unittests/torture_forwarded_tcpip_callback.c b/tests/unittests/torture_forwarded_tcpip_callback.c index c5dea584..036fde8d 100644 --- a/tests/unittests/torture_forwarded_tcpip_callback.c +++ b/tests/unittests/torture_forwarded_tcpip_callback.c @@ -31,6 +31,14 @@ struct server_thread_args { bool should_accept; }; +static bool is_server_ready = false; +static pthread_mutex_t server_mutex = PTHREAD_MUTEX_INITIALIZER; +static pthread_cond_t server_cond = PTHREAD_COND_INITIALIZER; + +static bool client_callbacks_initialised = false; +static pthread_mutex_t client_mutex = PTHREAD_MUTEX_INITIALIZER; +static pthread_cond_t client_cond = PTHREAD_COND_INITIALIZER; + static int setup(void **state) { struct hostkey_state *h = NULL; @@ -43,7 +51,7 @@ static int setup(void **state) return -1; } - h = malloc(sizeof(struct hostkey_state)); + h = (struct hostkey_state *)malloc(sizeof(struct hostkey_state)); assert_non_null(h); h->hostkey_path = strdup("/tmp/libssh_hostkey_XXXXXX"); @@ -62,6 +70,10 @@ static int setup(void **state) *state = h; + /* Reset before every test */ + is_server_ready = false; + client_callbacks_initialised = false; + return 0; } @@ -118,10 +130,16 @@ static void *server_thread(void *arg) server = ssh_new(); assert_non_null(server); - rc = ssh_bind_accept(sshbind, server); + rc = ssh_set_server_callbacks(server, &server_cb); assert_int_equal(rc, SSH_OK); - rc = ssh_set_server_callbacks(server, &server_cb); + /* Signal that the server is ready */ + pthread_mutex_lock(&server_mutex); + is_server_ready = true; + pthread_cond_signal(&server_cond); + pthread_mutex_unlock(&server_mutex); + + rc = ssh_bind_accept(sshbind, server); assert_int_equal(rc, SSH_OK); rc = ssh_handle_key_exchange(server); @@ -145,6 +163,13 @@ static void *server_thread(void *arg) /* Cleanup the event */ ssh_event_free(event); + /* Wait for client callbacks to be initialized before proceeding */ + pthread_mutex_lock(&client_mutex); + while (!client_callbacks_initialised) { + pthread_cond_wait(&client_cond, &client_mutex); + } + pthread_mutex_unlock(&client_mutex); + channel = ssh_channel_new(server); assert_non_null(channel); @@ -231,7 +256,12 @@ static void torture_forwarded_tcpip_callback(void **state, bool should_accept) rc = pthread_create(&server_pthread, NULL, server_thread, &args); assert_return_code(rc, errno); - usleep(200 * 1000); /* Give the server time to start */ + /* Wait for the server to be ready using condition variable */ + pthread_mutex_lock(&server_mutex); + while (!is_server_ready) { + pthread_cond_wait(&server_cond, &server_mutex); + } + pthread_mutex_unlock(&server_mutex); session = torture_ssh_session(NULL, "localhost", &server_port, "foo", "bar"); @@ -246,6 +276,12 @@ static void torture_forwarded_tcpip_callback(void **state, bool should_accept) rc = ssh_event_add_session(event, session); assert_int_equal(rc, SSH_OK); + /* Signal that client callbacks are initialized */ + pthread_mutex_lock(&client_mutex); + client_callbacks_initialised = true; + pthread_cond_signal(&client_cond); + pthread_mutex_unlock(&client_mutex); + event_rc = SSH_OK; while (channel_data.req_seen != 1 && event_rc == SSH_OK) { event_rc = ssh_event_dopoll(event, -1);