diff --git a/include/ma_tls.h b/include/ma_tls.h index 9ce49e7c..ec8bc239 100644 --- a/include/ma_tls.h +++ b/include/ma_tls.h @@ -145,6 +145,7 @@ unsigned int ma_tls_get_finger_print(MARIADB_TLS *ctls, char *fp, unsigned int f int ma_tls_get_protocol_version(MARIADB_TLS *ctls); const char *ma_pvio_tls_get_protocol_version(MARIADB_TLS *ctls); int ma_pvio_tls_get_protocol_version_id(MARIADB_TLS *ctls); +void ma_tls_set_connection(MYSQL *mysql); /* Function prototypes */ MARIADB_TLS *ma_pvio_tls_init(MYSQL *mysql); @@ -156,6 +157,7 @@ int ma_pvio_tls_verify_server_cert(MARIADB_TLS *ctls); const char *ma_pvio_tls_cipher(MARIADB_TLS *ctls); my_bool ma_pvio_tls_check_fp(MARIADB_TLS *ctls, const char *fp, const char *fp_list); my_bool ma_pvio_start_ssl(MARIADB_PVIO *pvio); +void ma_pvio_tls_set_connection(MYSQL *mysql); void ma_pvio_tls_end(); #endif /* _ma_tls_h_ */ diff --git a/libmariadb/ma_tls.c b/libmariadb/ma_tls.c index 3090d49e..a5401f95 100644 --- a/libmariadb/ma_tls.c +++ b/libmariadb/ma_tls.c @@ -229,4 +229,9 @@ end: } return rc; } + +void ma_pvio_tls_set_connection(MYSQL *mysql) +{ + return ma_tls_set_connection(mysql); +} #endif /* HAVE_TLS */ diff --git a/libmariadb/mariadb_lib.c b/libmariadb/mariadb_lib.c index 8dc2ec0e..1d090ec1 100644 --- a/libmariadb/mariadb_lib.c +++ b/libmariadb/mariadb_lib.c @@ -1830,6 +1830,11 @@ my_bool STDCALL mariadb_reconnect(MYSQL *mysql) mysql_close(mysql); *mysql=tmp_mysql; mysql->net.pvio->mysql= mysql; +#ifdef HAVE_TLS + /* CONC-604: Set new connection handle */ + if (mysql_get_ssl_cipher(mysql)) + ma_pvio_tls_set_connection(mysql); +#endif ma_net_clear(&mysql->net); mysql->affected_rows= ~(unsigned long long) 0; mysql->info= 0; diff --git a/libmariadb/secure/gnutls.c b/libmariadb/secure/gnutls.c index 1fc977ad..f7e81bf7 100644 --- a/libmariadb/secure/gnutls.c +++ b/libmariadb/secure/gnutls.c @@ -50,12 +50,6 @@ static int my_verify_callback(gnutls_session_t ssl); char tls_library_version[TLS_VERSION_LENGTH]; -struct st_gnutls_data { - MYSQL *mysql; - gnutls_privkey_t key; - gnutls_pcert_st cert; -}; - struct st_cipher_map { unsigned char sid[2]; const char *iana_name; @@ -799,18 +793,6 @@ const struct st_cipher_map tls_ciphers[]= NULL} }; -/* free data assigned to the connection */ -static void free_gnutls_data(struct st_gnutls_data *data) -{ - if (data) - { - if (data->key) - gnutls_privkey_deinit(data->key); - gnutls_pcert_deinit(&data->cert); - free(data); - } -} - /* map the gnutls cipher suite (defined by key exchange algorithm, cipher and mac algorithm) to the corresponding OpenSSL cipher name */ static const char *openssl_cipher_name(gnutls_kx_algorithm_t kx, @@ -912,20 +894,19 @@ static void ma_tls_set_error(MYSQL *mysql, void *ssl, int ssl_errno) alert_name= gnutls_alert_get_name(alert_desc); snprintf(ssl_error, MAX_SSL_ERR_LEN, "fatal alert received: %s", alert_name); - pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, 0, - ssl_error); + pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, ssl_error); return; } if ((ssl_error_reason= gnutls_strerror(ssl_errno))) { - pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, 0, + pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, ssl_error_reason); return; } snprintf(ssl_error, MAX_SSL_ERR_LEN, "SSL errno=%d", ssl_errno); - pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, - ssl_error); + pvio->set_error(mysql, CR_SSL_CONNECTION_ERROR, SQLSTATE_UNKNOWN, + ssl_error); } @@ -1126,7 +1107,6 @@ void *ma_tls_init(MYSQL *mysql) gnutls_session_t ssl= NULL; gnutls_certificate_credentials_t ctx; int ssl_error= 0; - struct st_gnutls_data *data= NULL; pthread_mutex_lock(&LOCK_gnutls_config); @@ -1139,11 +1119,7 @@ void *ma_tls_init(MYSQL *mysql) if ((ssl_error = gnutls_init(&ssl, GNUTLS_CLIENT | GNUTLS_NONBLOCK | GNUTLS_NO_SIGNAL)) < 0) goto error; - if (!(data= (struct st_gnutls_data *)calloc(1, sizeof(struct st_gnutls_data)))) - goto error; - - data->mysql= mysql; - gnutls_session_set_ptr(ssl, (void *)data); + gnutls_session_set_ptr(ssl, (void *)mysql); /* gnutls_certificate_set_retrieve_function2(GNUTLS_xcred, client_cert_callback); */ @@ -1159,7 +1135,6 @@ void *ma_tls_init(MYSQL *mysql) pthread_mutex_unlock(&LOCK_gnutls_config); return (void *)ssl; error: - free_gnutls_data(data); ma_tls_set_error(mysql, ssl, ssl_error); gnutls_certificate_free_credentials(ctx); if (ssl) @@ -1194,12 +1169,9 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) { gnutls_session_t ssl = (gnutls_session_t)ctls->ssl; my_bool blocking; - MYSQL *mysql; + MYSQL *mysql= (MYSQL *)gnutls_session_get_ptr(ssl); MARIADB_PVIO *pvio; int ret; - struct st_gnutls_data *data; - data= (struct st_gnutls_data *)gnutls_session_get_ptr(ssl); - mysql= data->mysql; if (!mysql) return 1; @@ -1301,9 +1273,13 @@ ssize_t ma_tls_read(MARIADB_TLS *ctls, const uchar* buffer, size_t length) while ((rc= gnutls_record_recv((gnutls_session_t)ctls->ssl, (void *)buffer, length)) <= 0) { if (rc != GNUTLS_E_AGAIN && rc != GNUTLS_E_INTERRUPTED) - return rc; + break; if (pvio->methods->wait_io_or_timeout(pvio, TRUE, pvio->mysql->options.read_timeout) < 1) - return rc; + break; + } + if (rc <= 0) { + MYSQL *mysql= (MYSQL *)gnutls_session_get_ptr(ctls->ssl); + ma_tls_set_error(mysql, ctls->ssl, rc); } return rc; } @@ -1316,9 +1292,13 @@ ssize_t ma_tls_write(MARIADB_TLS *ctls, const uchar* buffer, size_t length) while ((rc= gnutls_record_send((gnutls_session_t)ctls->ssl, (void *)buffer, length)) <= 0) { if (rc != GNUTLS_E_AGAIN && rc != GNUTLS_E_INTERRUPTED) - return rc; + break; if (pvio->methods->wait_io_or_timeout(pvio, TRUE, pvio->mysql->options.write_timeout) < 1) - return rc; + break; + } + if (rc <= 0) { + MYSQL *mysql= (MYSQL *)gnutls_session_get_ptr(ctls->ssl); + ma_tls_set_error(mysql, ctls->ssl, rc); } return rc; } @@ -1328,14 +1308,11 @@ my_bool ma_tls_close(MARIADB_TLS *ctls) if (ctls->ssl) { gnutls_certificate_credentials_t ctx; - struct st_gnutls_data *data= - (struct st_gnutls_data *)gnutls_session_get_ptr(ctls->ssl); /* this would be the correct way, however can't detect afterwards if the socket is closed or not, so we don't send encrypted finish alert. rc= gnutls_bye((gnutls_session_t )ctls->ssl, GNUTLS_SHUT_WR); */ - free_gnutls_data(data); gnutls_credentials_get(ctls->ssl, GNUTLS_CRD_CERTIFICATE, (void **)&ctx); gnutls_certificate_free_keys(ctx); gnutls_certificate_free_cas(ctx); @@ -1372,10 +1349,7 @@ const char *ma_tls_get_cipher(MARIADB_TLS *ctls) static int my_verify_callback(gnutls_session_t ssl) { unsigned int status= 0; - struct st_gnutls_data *data= (struct st_gnutls_data *)gnutls_session_get_ptr(ssl); - MYSQL *mysql; - - mysql= data->mysql; + MYSQL *mysql= (MYSQL *)gnutls_session_get_ptr(ssl); CLEAR_CLIENT_ERROR(mysql); @@ -1419,13 +1393,11 @@ unsigned int ma_tls_get_finger_print(MARIADB_TLS *ctls, char *fp, unsigned int l size_t fp_len= len; const gnutls_datum_t *cert_list; unsigned int cert_list_size; - struct st_gnutls_data *data; if (!ctls || !ctls->ssl) return 0; - data= (struct st_gnutls_data *)gnutls_session_get_ptr(ctls->ssl); - mysql= (MYSQL *)data->mysql; + mysql= (MYSQL *)gnutls_session_get_ptr(ctls->ssl); cert_list = gnutls_certificate_get_peers (ctls->ssl, &cert_list_size); if (cert_list == NULL) @@ -1454,4 +1426,9 @@ int ma_tls_get_protocol_version(MARIADB_TLS *ctls) return gnutls_protocol_get_version(ctls->ssl) - 1; } + +void ma_tls_set_connection(MYSQL *mysql) +{ + (void)gnutls_session_set_ptr(mysql->net.pvio->ctls->ssl, (void *)mysql); +} #endif /* HAVE_GNUTLS */ diff --git a/libmariadb/secure/openssl.c b/libmariadb/secure/openssl.c index 62fcfdae..d0b9d0ea 100644 --- a/libmariadb/secure/openssl.c +++ b/libmariadb/secure/openssl.c @@ -591,16 +591,14 @@ ssize_t ma_tls_read(MARIADB_TLS *ctls, const uchar* buffer, size_t length) { int error= SSL_get_error((SSL *)ctls->ssl, rc); if (error != SSL_ERROR_WANT_READ) - { - if (error == SSL_ERROR_SSL || errno == 0) - { - MYSQL *mysql= SSL_get_app_data(ctls->ssl); - ma_tls_set_error(mysql); - } - return rc; - } + break; if (pvio->methods->wait_io_or_timeout(pvio, TRUE, pvio->mysql->options.read_timeout) < 1) - return rc; + break; + } + if (rc <= 0) + { + MYSQL *mysql= SSL_get_app_data(ctls->ssl); + ma_tls_set_error(mysql); } return rc; } @@ -614,16 +612,14 @@ ssize_t ma_tls_write(MARIADB_TLS *ctls, const uchar* buffer, size_t length) { int error= SSL_get_error((SSL *)ctls->ssl, rc); if (error != SSL_ERROR_WANT_WRITE) - { - if (error == SSL_ERROR_SSL || errno == 0) - { - MYSQL *mysql= SSL_get_app_data(ctls->ssl); - ma_tls_set_error(mysql); - } - return rc; - } + break; if (pvio->methods->wait_io_or_timeout(pvio, TRUE, pvio->mysql->options.write_timeout) < 1) - return rc; + break; + } + if (rc <= 0) + { + MYSQL *mysql= SSL_get_app_data(ctls->ssl); + ma_tls_set_error(mysql); } return rc; } @@ -782,3 +778,8 @@ int ma_tls_get_protocol_version(MARIADB_TLS *ctls) return SSL_version(ctls->ssl) & 0xFF; } +void ma_tls_set_connection(MYSQL *mysql) +{ + (void)SSL_set_app_data(mysql->net.pvio->ctls->ssl, mysql); +} + diff --git a/libmariadb/secure/schannel.c b/libmariadb/secure/schannel.c index ff1833d4..acd7394e 100644 --- a/libmariadb/secure/schannel.c +++ b/libmariadb/secure/schannel.c @@ -560,3 +560,8 @@ unsigned int ma_tls_get_finger_print(MARIADB_TLS *ctls, char *fp, unsigned int l CertFreeCertificateContext(pRemoteCertContext); return len; } + +void ma_tls_set_connection(MYSQL *mysql __attribute__((unused))) +{ + return; +} diff --git a/unittest/libmariadb/connection.c b/unittest/libmariadb/connection.c index 7e35b529..a7631dc8 100644 --- a/unittest/libmariadb/connection.c +++ b/unittest/libmariadb/connection.c @@ -685,13 +685,16 @@ int test_connection_timeout2(MYSQL *unused __attribute__((unused))) SKIP_SKYSQL; SKIP_MAXSCALE; +// SKIP_TLS; mysql= mysql_init(NULL); mysql_options(mysql, MYSQL_OPT_CONNECT_TIMEOUT, (unsigned int *)&timeout); - mysql_options(mysql, MYSQL_INIT_COMMAND, "set @a:=SLEEP(6)"); + mysql_options(mysql, MYSQL_INIT_COMMAND, "set @a:=SLEEP(7)"); start= time(NULL); if (my_test_connect(mysql, hostname, username, password, schema, port, NULL, CLIENT_REMEMBER_OPTIONS)) { + elapsed= time(NULL) - start; + diag("elapsed: %lu", (unsigned long)elapsed); diag("timeout error expected"); return FAIL; } diff --git a/unittest/libmariadb/my_test.h b/unittest/libmariadb/my_test.h index 17907a9e..8b8da62f 100644 --- a/unittest/libmariadb/my_test.h +++ b/unittest/libmariadb/my_test.h @@ -73,6 +73,13 @@ if (IS_SKYSQL(hostname)) \ #define SKIP_NOTLS #endif +#define SKIP_TLS \ +if (force_tls)\ +{\ + diag("Test doesn't work with TLS");\ + return SKIP;\ +} + MYSQL *mysql_default = NULL; /* default connection */ #define IS_MAXSCALE()\