diff --git a/include/libssh/libssh.h b/include/libssh/libssh.h index 6a034db9..7857a77b 100644 --- a/include/libssh/libssh.h +++ b/include/libssh/libssh.h @@ -841,6 +841,7 @@ LIBSSH_API int ssh_buffer_add_data(ssh_buffer buffer, const void *data, uint32_t LIBSSH_API uint32_t ssh_buffer_get_data(ssh_buffer buffer, void *data, uint32_t requestedlen); LIBSSH_API void *ssh_buffer_get(ssh_buffer buffer); LIBSSH_API uint32_t ssh_buffer_get_len(ssh_buffer buffer); +LIBSSH_API int ssh_session_set_disconnect_message(ssh_session session, const char *message); #ifndef LIBSSH_LEGACY_0_4 #include "libssh/legacy.h" diff --git a/include/libssh/session.h b/include/libssh/session.h index 55a10e48..0a6fb080 100644 --- a/include/libssh/session.h +++ b/include/libssh/session.h @@ -136,6 +136,7 @@ struct ssh_session_struct { the server */ char *discon_msg; /* disconnect message from the remote host */ + char *disconnect_message; /* disconnect message to be set */ ssh_buffer in_buffer; PACKET in_packet; ssh_buffer out_buffer; diff --git a/src/client.c b/src/client.c index 5ed893b5..e958a523 100644 --- a/src/client.c +++ b/src/client.c @@ -690,6 +690,39 @@ int ssh_get_openssh_version(ssh_session session) return session->openssh; } +/** + * @brief Add disconnect message when ssh_session is disconnected + * To add a disconnect message to give peer a better hint. + * @param session The SSH session to use. + * @param message The message to send after the session is disconnected. + * If no message is passed then a default message i.e + * "Bye Bye" will be sent. + */ +int +ssh_session_set_disconnect_message(ssh_session session, const char *message) +{ + if (session == NULL) { + return SSH_ERROR; + } + + if (message == NULL || strlen(message) == 0) { + SAFE_FREE(session->disconnect_message); //To free any message set earlier. + session->disconnect_message = strdup("Bye Bye") ; + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + return SSH_ERROR; + } + return SSH_OK; + } + SAFE_FREE(session->disconnect_message); //To free any message set earlier. + session->disconnect_message = strdup(message); + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + return SSH_ERROR; + } + return SSH_OK; +} + /** * @brief Disconnect from a session (client or server). @@ -712,12 +745,20 @@ ssh_disconnect(ssh_session session) return; } + if (session->disconnect_message == NULL) { + session->disconnect_message = strdup("Bye Bye") ; + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + goto error; + } + } + if (session->socket != NULL && ssh_socket_is_open(session->socket)) { rc = ssh_buffer_pack(session->out_buffer, "bdss", SSH2_MSG_DISCONNECT, SSH2_DISCONNECT_BY_APPLICATION, - "Bye Bye", + session->disconnect_message, ""); /* language tag */ if (rc != SSH_OK) { ssh_set_error_oom(session); @@ -772,6 +813,7 @@ error: session->auth.supported_methods = 0; SAFE_FREE(session->serverbanner); SAFE_FREE(session->clientbanner); + SAFE_FREE(session->disconnect_message); if (session->ssh_message_list) { ssh_message msg = NULL; diff --git a/src/session.c b/src/session.c index 7eacc925..61b9720a 100644 --- a/src/session.c +++ b/src/session.c @@ -299,6 +299,7 @@ void ssh_free(ssh_session session) SAFE_FREE(session->serverbanner); SAFE_FREE(session->clientbanner); SAFE_FREE(session->banner); + SAFE_FREE(session->disconnect_message); SAFE_FREE(session->opts.bindaddr); SAFE_FREE(session->opts.custombanner); diff --git a/tests/server/torture_server.c b/tests/server/torture_server.c index fecf86c9..c21f3612 100644 --- a/tests/server/torture_server.c +++ b/tests/server/torture_server.c @@ -370,6 +370,66 @@ static void torture_server_unknown_global_request(void **state) ssh_channel_close(channel); } +static void torture_server_set_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + const char *message = "Goodbye"; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(session,message); + assert_ssh_return_code(session, rc); + assert_string_equal(session->disconnect_message,message); +} + +static void torture_null_server_set_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(NULL,"Goodbye"); + assert_int_equal(rc, SSH_ERROR); +} + +static void torture_server_set_null_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(session,NULL); + assert_int_equal(rc, SSH_OK); + assert_string_equal(session->disconnect_message,"Bye Bye"); +} + int torture_run_tests(void) { int rc; struct CMUnitTest tests[] = { @@ -388,6 +448,15 @@ int torture_run_tests(void) { cmocka_unit_test_setup_teardown(torture_server_unknown_global_request, session_setup, session_teardown), + cmocka_unit_test_setup_teardown(torture_server_set_disconnect_message, + session_setup, + session_teardown), + cmocka_unit_test_setup_teardown(torture_null_server_set_disconnect_message, + session_setup, + session_teardown), + cmocka_unit_test_setup_teardown(torture_server_set_null_disconnect_message, + session_setup, + session_teardown), }; ssh_init();