diff --git a/include/libssh/socket.h b/include/libssh/socket.h index ae56f5f6..cdd3c837 100644 --- a/include/libssh/socket.h +++ b/include/libssh/socket.h @@ -35,6 +35,7 @@ void ssh_socket_reset(ssh_socket s); void ssh_socket_free(ssh_socket s); void ssh_socket_set_fd(ssh_socket s, socket_t fd); socket_t ssh_socket_get_fd(ssh_socket s); +void ssh_socket_set_connected(ssh_socket s, struct ssh_poll_handle_struct *p); int ssh_socket_unix(ssh_socket s, const char *path); void ssh_execute_command(const char *command, socket_t in, socket_t out); #ifndef _WIN32 diff --git a/src/bind.c b/src/bind.c index d1dbac94..226e217c 100644 --- a/src/bind.c +++ b/src/bind.c @@ -426,6 +426,7 @@ void ssh_bind_free(ssh_bind sshbind){ int ssh_bind_accept_fd(ssh_bind sshbind, ssh_session session, socket_t fd) { + ssh_poll_handle handle = NULL; int i, rc; if (sshbind == NULL) { @@ -517,7 +518,12 @@ int ssh_bind_accept_fd(ssh_bind sshbind, ssh_session session, socket_t fd) return SSH_ERROR; } ssh_socket_set_fd(session->socket, fd); - ssh_socket_get_poll_handle(session->socket); + handle = ssh_socket_get_poll_handle(session->socket); + if (handle == NULL) { + ssh_set_error_oom(sshbind); + return SSH_ERROR; + } + ssh_socket_set_connected(session->socket, handle); /* We must try to import any keys that could be imported in case * we are not using ssh_bind_listen (which is the other place diff --git a/src/socket.c b/src/socket.c index 78c34fac..a1aa9611 100644 --- a/src/socket.c +++ b/src/socket.c @@ -223,6 +223,15 @@ void ssh_socket_set_callbacks(ssh_socket s, ssh_socket_callbacks callbacks) s->callbacks = callbacks; } +void ssh_socket_set_connected(ssh_socket s, struct ssh_poll_handle_struct *p) +{ + s->state = SSH_SOCKET_CONNECTED; + /* POLLOUT is the event to wait for in a nonblocking connect */ + if (p != NULL) { + ssh_poll_set_events(p, POLLIN | POLLOUT); + } +} + /** * @brief SSH poll callback. This callback will be used when an event * caught on the socket. @@ -345,10 +354,7 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p, /* First, POLLOUT is a sign we may be connected */ if (s->state == SSH_SOCKET_CONNECTING) { SSH_LOG(SSH_LOG_PACKET, "Received POLLOUT in connecting state"); - s->state = SSH_SOCKET_CONNECTED; - if (p != NULL) { - ssh_poll_set_events(p, POLLOUT | POLLIN); - } + ssh_socket_set_connected(s, p); rc = ssh_socket_set_blocking(ssh_socket_get_fd(s)); if (rc < 0) { @@ -949,6 +955,7 @@ int ssh_socket_connect_proxycommand(ssh_socket s, const char *command) { socket_t pair[2]; + ssh_poll_handle h = NULL; int pid; int rc; @@ -971,10 +978,12 @@ ssh_socket_connect_proxycommand(ssh_socket s, const char *command) close(pair[0]); SSH_LOG(SSH_LOG_DEBUG, "ProxyCommand connection pipe: [%d,%d]",pair[0],pair[1]); ssh_socket_set_fd(s, pair[1]); - s->state=SSH_SOCKET_CONNECTED; - s->fd_is_socket=0; - /* POLLOUT is the event to wait for in a nonblocking connect */ - ssh_poll_set_events(ssh_socket_get_poll_handle(s), POLLIN | POLLOUT); + s->fd_is_socket = 0; + h = ssh_socket_get_poll_handle(s); + if (h == NULL) { + return SSH_ERROR; + } + ssh_socket_set_connected(s, h); if (s->callbacks && s->callbacks->connected) { s->callbacks->connected(SSH_SOCKET_CONNECTED_OK, 0, s->callbacks->userdata); }