diff --git a/include/mariadb_async.h b/include/mariadb_async.h index d655bd8e..cd5385be 100644 --- a/include/mariadb_async.h +++ b/include/mariadb_async.h @@ -30,10 +30,8 @@ extern ssize_t my_send_async(MARIADB_PVIO *pvio, extern my_bool my_io_wait_async(struct mysql_async_context *b, enum enum_pvio_io_event event, int timeout); #ifdef HAVE_TLS -extern int my_ssl_read_async(struct mysql_async_context *b, MARIADB_TLS *tls, - void *buf, int size); -extern int my_ssl_write_async(struct mysql_async_context *b, MARIADB_TLS *tls, - const void *buf, int size); +extern ssize_t ma_tls_read_async(MARIADB_PVIO *pvio, const uchar *buf, size_t size); +extern ssize_t ma_tls_write_async(MARIADB_PVIO *pvio, const uchar *buf, size_t size); #endif #endif /* MYSQL_ASYNC_H */ diff --git a/libmariadb/ma_pvio.c b/libmariadb/ma_pvio.c index 6c940f3a..477d925c 100644 --- a/libmariadb/ma_pvio.c +++ b/libmariadb/ma_pvio.c @@ -223,7 +223,11 @@ ssize_t ma_pvio_read(MARIADB_PVIO *pvio, uchar *buffer, size_t length) return -1; if (IS_PVIO_ASYNC_ACTIVE(pvio)) { - r= ma_pvio_read_async(pvio, buffer, length); + r= +#if !defined(HAVE_SCHANNEL) + (pvio->ctls) ? ma_tls_read_async(pvio, buffer, length) : +#endif + (ssize_t)ma_pvio_read_async(pvio, buffer, length); goto end; } else @@ -343,18 +347,13 @@ ssize_t ma_pvio_write(MARIADB_PVIO *pvio, const uchar *buffer, size_t length) if (!pvio) return -1; - /* secure connection */ -#ifdef HAVE_TLS - if (pvio->ctls) - { - r= ma_pvio_tls_write(pvio->ctls, buffer, length); - goto end; - } - else -#endif if (IS_PVIO_ASYNC_ACTIVE(pvio)) { - r= ma_pvio_write_async(pvio, buffer, length); + r= +#if !defined(HAVE_SCHANNEL) + (pvio->ctls) ? ma_tls_write_async(pvio, buffer, length) : +#endif + ma_pvio_write_async(pvio, buffer, length); goto end; } else @@ -369,6 +368,14 @@ ssize_t ma_pvio_write(MARIADB_PVIO *pvio, const uchar *buffer, size_t length) ma_pvio_blocking(pvio, TRUE, &old_mode); } } + /* secure connection */ +#ifdef HAVE_TLS + if (pvio->ctls) + { + r= ma_pvio_tls_write(pvio->ctls, buffer, length); + goto end; + } +#endif if (pvio->methods->write) r= pvio->methods->write(pvio, buffer, length); diff --git a/libmariadb/secure/gnutls.c b/libmariadb/secure/gnutls.c index 9ba95849..2035583b 100644 --- a/libmariadb/secure/gnutls.c +++ b/libmariadb/secure/gnutls.c @@ -31,6 +31,8 @@ #include #include #include +#include +#include pthread_mutex_t LOCK_gnutls_config; @@ -1301,6 +1303,55 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) return 0; } +ssize_t ma_tls_write_async(MARIADB_PVIO *pvio, const uchar *buffer, size_t length) +{ + ssize_t res; + struct mysql_async_context *b= pvio->mysql->options.extension->async_context; + MARIADB_TLS *ctls= pvio->ctls; + + for (;;) + { + b->events_to_wait_for= 0; + res= gnutls_record_send((gnutls_session_t)ctls->ssl, (void *)buffer, length); + if (res > 0) + return res; + if (res == GNUTLS_E_AGAIN) + b->events_to_wait_for|= MYSQL_WAIT_WRITE; + else + return res; + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(TRUE, b->suspend_resume_hook_user_data); + my_context_yield(&b->async_context); + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(FALSE, b->suspend_resume_hook_user_data); + } +} + + +ssize_t ma_tls_read_async(MARIADB_PVIO *pvio, const uchar *buffer, size_t length) +{ + ssize_t res; + struct mysql_async_context *b= pvio->mysql->options.extension->async_context; + MARIADB_TLS *ctls= pvio->ctls; + + for (;;) + { + b->events_to_wait_for= 0; + res= gnutls_record_recv((gnutls_session_t)ctls->ssl, (void *)buffer, length); + if (res > 0) + return res; + if (res == GNUTLS_E_AGAIN) + b->events_to_wait_for|= MYSQL_WAIT_READ; + else + return res; + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(TRUE, b->suspend_resume_hook_user_data); + my_context_yield(&b->async_context); + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(FALSE, b->suspend_resume_hook_user_data); + } +} + ssize_t ma_tls_read(MARIADB_TLS *ctls, const uchar* buffer, size_t length) { return gnutls_record_recv((gnutls_session_t )ctls->ssl, (void *)buffer, length); diff --git a/libmariadb/secure/openssl.c b/libmariadb/secure/openssl.c index e05d4df2..fe599ff4 100644 --- a/libmariadb/secure/openssl.c +++ b/libmariadb/secure/openssl.c @@ -56,6 +56,9 @@ #endif #include +#include +#include + extern my_bool ma_tls_initialized; extern unsigned int mariadb_deinitialize_ssl; @@ -639,13 +642,68 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) return 0; } +static my_bool +ma_tls_async_check_result(int res, struct mysql_async_context *b, SSL *ssl) +{ + int ssl_err; + b->events_to_wait_for= 0; + if (res >= 0) + return 1; + ssl_err= SSL_get_error(ssl, res); + if (ssl_err == SSL_ERROR_WANT_READ) + b->events_to_wait_for|= MYSQL_WAIT_READ; + else if (ssl_err == SSL_ERROR_WANT_WRITE) + b->events_to_wait_for|= MYSQL_WAIT_WRITE; + else + return 1; + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(TRUE, b->suspend_resume_hook_user_data); + my_context_yield(&b->async_context); + if (b->suspend_resume_hook) + (*b->suspend_resume_hook)(FALSE, b->suspend_resume_hook_user_data); + return 0; +} + +ssize_t ma_tls_read_async(MARIADB_PVIO *pvio, + const unsigned char *buffer, + size_t length) +{ + int res; + struct mysql_async_context *b= pvio->mysql->options.extension->async_context; + MARIADB_TLS *ctls= pvio->ctls; + + for (;;) + { + res= SSL_read((SSL *)ctls->ssl, (void *)buffer, length); + if (ma_tls_async_check_result(res, b, (SSL *)ctls->ssl)) + return res; + } +} + +ssize_t ma_tls_write_async(MARIADB_PVIO *pvio, + const unsigned char *buffer, + size_t length) +{ + int res; + struct mysql_async_context *b= pvio->mysql->options.extension->async_context; + MARIADB_TLS *ctls= pvio->ctls; + + for (;;) + { + res= SSL_write((SSL *)ctls->ssl, (void *)buffer, length); + if (ma_tls_async_check_result(res, b, (SSL *)ctls->ssl)) + return res; + } +} + + ssize_t ma_tls_read(MARIADB_TLS *ctls, const uchar* buffer, size_t length) { return SSL_read((SSL *)ctls->ssl, (void *)buffer, (int)length); } ssize_t ma_tls_write(MARIADB_TLS *ctls, const uchar* buffer, size_t length) -{ +{ return SSL_write((SSL *)ctls->ssl, (void *)buffer, (int)length); } diff --git a/unittest/libmariadb/async.c b/unittest/libmariadb/async.c index 17dc363f..079f19d8 100644 --- a/unittest/libmariadb/async.c +++ b/unittest/libmariadb/async.c @@ -155,7 +155,9 @@ static int async1(MYSQL *unused __attribute__((unused))) mysql_options(&mysql, MYSQL_OPT_READ_TIMEOUT, &default_timeout); mysql_options(&mysql, MYSQL_OPT_CONNECT_TIMEOUT, &default_timeout); mysql_options(&mysql, MYSQL_OPT_WRITE_TIMEOUT, &default_timeout); - mysql_options(&mysql, MYSQL_READ_DEFAULT_GROUP, "myapp"); + mysql_options(&mysql, MYSQL_READ_DEFAULT_GROUP, "myapp"); + if (force_tls) + mysql_ssl_set(&mysql, NULL, NULL, NULL, NULL,NULL); /* Returns 0 when done, else flag for what to wait for when need to block. */ status= mysql_real_connect_start(&ret, &mysql, hostname, username, password, schema, port, socketname, 0); @@ -170,6 +172,12 @@ static int async1(MYSQL *unused __attribute__((unused))) FAIL_IF(!ret, "Failed to mysql_real_connect()"); } + if (force_tls && !mysql_get_ssl_cipher(&mysql)) + { + diag("Error: No tls connection"); + return FAIL; + } + status= mysql_real_query_start(&err, &mysql, SL("SHOW STATUS")); while (status) {