diff --git a/src/poll.c b/src/poll.c index ba8e3706..9c6c097e 100644 --- a/src/poll.c +++ b/src/poll.c @@ -412,7 +412,8 @@ short ssh_poll_get_events(ssh_poll_handle p) /** * @brief Set the events of a poll object. The events will also be propagated - * to an associated poll context. + * to an associated poll context unless the fd is locked. In that case, + * only the POLLOUT can be set. * * @param p Pointer to an already allocated poll object. * @param events Poll events. @@ -420,8 +421,14 @@ short ssh_poll_get_events(ssh_poll_handle p) void ssh_poll_set_events(ssh_poll_handle p, short events) { p->events = events; - if (p->ctx != NULL && !p->lock) { - p->ctx->pollfds[p->x.idx].events = events; + if (p->ctx != NULL) { + if (!p->lock) { + p->ctx->pollfds[p->x.idx].events = events; + } else if (!(p->ctx->pollfds[p->x.idx].events & POLLOUT)) { + /* if locked, allow only setting POLLOUT to prevent recursive + * callbacks */ + p->ctx->pollfds[p->x.idx].events = events & POLLOUT; + } } } @@ -691,6 +698,15 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) return SSH_ERROR; } + /* Ignore any pollin events on locked sockets as that means we are called + * recursively and we only want process the POLLOUT events here to flush + * output buffer */ + for (i = 0; i < ctx->polls_used; i++) { + /* The lock prevents invoking POLLIN events: drop them now */ + if (ctx->pollptrs[i]->lock) { + ctx->pollfds[i].events &= ~POLLIN; + } + } ssh_timestamp_init(&ts); do { int tm = ssh_timeout_update(&ts, timeout); @@ -706,7 +722,7 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) used = ctx->polls_used; for (i = 0; i < used && rc > 0; ) { - if (!ctx->pollfds[i].revents || ctx->pollptrs[i]->lock) { + if (ctx->pollfds[i].revents == 0) { i++; } else { int ret; @@ -716,7 +732,7 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) revents = ctx->pollfds[i].revents; /* avoid having any event caught during callback */ ctx->pollfds[i].events = 0; - p->lock = 1; + p->lock++; if (p->cb && (ret = p->cb(p, fd, revents, p->cb_data)) < 0) { if (ret == -2) { return -1; @@ -727,7 +743,7 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) } else { ctx->pollfds[i].revents = 0; ctx->pollfds[i].events = p->events; - p->lock = 0; + p->lock--; i++; }