diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index 45997cd2..d8de3ff1 100644 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -101,6 +101,7 @@ if (UNIX AND NOT WIN32) ${LIBSSH_THREAD_UNIT_TESTS} torture_unit_server torture_server_x11 + torture_forwarded_tcpip_callback ) endif (WITH_SERVER) endif (UNIX AND NOT WIN32) diff --git a/tests/unittests/torture_forwarded_tcpip_callback.c b/tests/unittests/torture_forwarded_tcpip_callback.c new file mode 100644 index 00000000..c5dea584 --- /dev/null +++ b/tests/unittests/torture_forwarded_tcpip_callback.c @@ -0,0 +1,299 @@ +#include "config.h" + +#define LIBSSH_STATIC + +#include +#include +#include +#include +#include + +#include "torture.h" +#include "torture_key.h" + +#include + +#define TEST_SERVER_PORT 2222 +#define TEST_DEST_HOST "localhost" +#define TEST_DEST_PORT 12345 +#define TEST_ORIG_HOST "127.0.0.1" +#define TEST_ORIG_PORT 54321 + +struct hostkey_state { + const char *hostkey; + char *hostkey_path; + enum ssh_keytypes_e key_type; + int fd; +}; + +struct server_thread_args { + struct hostkey_state *h; + bool should_accept; +}; + +static int setup(void **state) +{ + struct hostkey_state *h = NULL; + mode_t mask; + int rc; + + ssh_threads_set_callbacks(ssh_threads_get_pthread()); + rc = ssh_init(); + if (rc != SSH_OK) { + return -1; + } + + h = malloc(sizeof(struct hostkey_state)); + assert_non_null(h); + + h->hostkey_path = strdup("/tmp/libssh_hostkey_XXXXXX"); + assert_non_null(h->hostkey_path); + + mask = umask(S_IRWXO | S_IRWXG); + h->fd = mkstemp(h->hostkey_path); + umask(mask); + assert_return_code(h->fd, errno); + close(h->fd); + + h->key_type = SSH_KEYTYPE_ECDSA_P256; + h->hostkey = torture_get_testkey(h->key_type, 0); + + torture_write_file(h->hostkey_path, h->hostkey); + + *state = h; + + return 0; +} + +static int teardown(void **state) +{ + struct hostkey_state *h = (struct hostkey_state *)*state; + + unlink(h->hostkey_path); + free(h->hostkey_path); + free(h); + + ssh_finalize(); + + return 0; +} + +static int auth_password_accept(ssh_session session, + const char *user, + const char *password, + void *userdata) +{ + /* unused */ + (void)session; + (void)user; + (void)password; + (void)userdata; + + return SSH_AUTH_SUCCESS; +} + +static void *server_thread(void *arg) +{ + struct server_thread_args *args = (struct server_thread_args *)arg; + struct hostkey_state *h = args->h; + bool should_accept = args->should_accept; + ssh_bind sshbind = NULL; + ssh_session server = NULL; + ssh_channel channel = NULL; + ssh_event event = NULL; + int rc; + + struct ssh_server_callbacks_struct server_cb = { + .auth_password_function = auth_password_accept, + }; + ssh_callbacks_init(&server_cb); + + /* Create server */ + sshbind = torture_ssh_bind("localhost", + TEST_SERVER_PORT, + h->key_type, + h->hostkey_path); + assert_non_null(sshbind); + + server = ssh_new(); + assert_non_null(server); + + rc = ssh_bind_accept(sshbind, server); + assert_int_equal(rc, SSH_OK); + + rc = ssh_set_server_callbacks(server, &server_cb); + assert_int_equal(rc, SSH_OK); + + rc = ssh_handle_key_exchange(server); + assert_int_equal(rc, SSH_OK); + + /* Handle client connection */ + event = ssh_event_new(); + assert_non_null(event); + + rc = ssh_event_add_session(event, server); + assert_int_equal(rc, SSH_OK); + + /* Poll until authentication is complete */ + while (server->session_state != SSH_SESSION_STATE_AUTHENTICATED) { + rc = ssh_event_dopoll(event, -1); + if (rc == SSH_ERROR) { + break; + } + } + + /* Cleanup the event */ + ssh_event_free(event); + + channel = ssh_channel_new(server); + assert_non_null(channel); + + rc = ssh_channel_open_reverse_forward(channel, + TEST_DEST_HOST, + TEST_DEST_PORT, + TEST_ORIG_HOST, + TEST_ORIG_PORT); + if (should_accept) { + assert_int_equal(rc, SSH_OK); + } else { + assert_int_equal(rc, SSH_ERROR); + } + + ssh_channel_close(channel); + ssh_channel_free(channel); + ssh_bind_free(sshbind); + ssh_free(server); + + return NULL; +} + +struct channel_data { + /* Whether the callback should accept the channel open request */ + bool should_accept; + + int req_seen; + char *dest_host; + uint32_t dest_port; + char *orig_host; + uint32_t orig_port; +}; + +static ssh_channel channel_forwarded_tcpip_callback(ssh_session session, + const char *dest_host, + int dest_port, + const char *orig_host, + int orig_port, + void *userdata) +{ + struct channel_data *channel_data = (struct channel_data *)userdata; + ssh_channel channel = NULL; + + /* Record that we've seen a forwarded-tcpip request and store the parameters + */ + channel_data->req_seen = 1; + channel_data->dest_host = strdup(dest_host); + channel_data->dest_port = dest_port; + channel_data->orig_host = strdup(orig_host); + channel_data->orig_port = orig_port; + + /* Create and return a new channel for this request */ + if (channel_data->should_accept) { + channel = ssh_channel_new(session); + } + + return channel; +} + +static void torture_forwarded_tcpip_callback(void **state, bool should_accept) +{ + int rc, event_rc; + pthread_t server_pthread; + ssh_session session = NULL; + ssh_event event = NULL; + struct channel_data channel_data; + unsigned int server_port = TEST_SERVER_PORT; + + struct server_thread_args args = { + .h = (struct hostkey_state *)*state, + .should_accept = should_accept, + }; + + struct ssh_callbacks_struct client_cb = { + .userdata = &channel_data, + .channel_open_request_forwarded_tcpip_function = + channel_forwarded_tcpip_callback, + }; + ssh_callbacks_init(&client_cb); + + memset(&channel_data, 0, sizeof(channel_data)); + channel_data.should_accept = should_accept; + + rc = pthread_create(&server_pthread, NULL, server_thread, &args); + assert_return_code(rc, errno); + + usleep(200 * 1000); /* Give the server time to start */ + + session = + torture_ssh_session(NULL, "localhost", &server_port, "foo", "bar"); + assert_non_null(session); + + rc = ssh_set_callbacks(session, &client_cb); + assert_int_equal(rc, SSH_OK); + + event = ssh_event_new(); + assert_non_null(event); + + rc = ssh_event_add_session(event, session); + assert_int_equal(rc, SSH_OK); + + event_rc = SSH_OK; + while (channel_data.req_seen != 1 && event_rc == SSH_OK) { + event_rc = ssh_event_dopoll(event, -1); + } + + /* Cleanup */ + ssh_event_free(event); + ssh_free(session); + + rc = pthread_join(server_pthread, NULL); + assert_int_equal(rc, 0); + + /* Verify forwarded-tcpip request parameters */ + assert_true(channel_data.req_seen); + assert_string_equal(channel_data.dest_host, TEST_DEST_HOST); + assert_int_equal(channel_data.dest_port, TEST_DEST_PORT); + assert_string_equal(channel_data.orig_host, TEST_ORIG_HOST); + assert_int_equal(channel_data.orig_port, TEST_ORIG_PORT); + + /* Free allocated memory */ + free(channel_data.dest_host); + free(channel_data.orig_host); +} + +static void torture_forwarded_tcpip_callback_success(void **state) +{ + torture_forwarded_tcpip_callback(state, true); +} + +static void torture_forwarded_tcpip_callback_failure(void **state) +{ + torture_forwarded_tcpip_callback(state, false); +} + +int torture_run_tests(void) +{ + int rc; + const struct CMUnitTest tests[] = { + cmocka_unit_test_setup_teardown( + torture_forwarded_tcpip_callback_success, + setup, + teardown), + cmocka_unit_test_setup_teardown( + torture_forwarded_tcpip_callback_failure, + setup, + teardown), + }; + + rc = cmocka_run_group_tests(tests, NULL, NULL); + return rc; +}