diff --git a/.gitignore b/.gitignore index 479bc936..5786a202 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,8 @@ mysql_config/mysql_config *.hex *.dgcov +.*.swp +.gdb_history CTestTestfile.cmake cmake_install.cmake bin/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 0fe85edd..649e700a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -157,9 +157,17 @@ ENDIF() # various defines for generating include/mysql_version.h SET(PROTOCOL_VERSION 10) # we adapted new password option from PHP's mysqlnd ! -SET(MARIADB_CLIENT_VERSION_MAJOR "10") -SET(MARIADB_CLIENT_VERSION_MINOR "2") -SET(MARIADB_CLIENT_VERSION_PATCH "5") +# if C/C is build as subproject inside MariaDB server tree we will +# use the version defined by server +IF(MYSQL_VERSION_MAJOR) + SET(MARIADB_CLIENT_VERSION_MAJOR ${MYSQL_VERSION_MAJOR}) + SET(MARIADB_CLIENT_VERSION_MINOR ${MYSQL_VERSION_MINOR}) + SET(MARIADB_CLIENT_VERSION_PATCH ${MYSQL_VERSION_PATCH}) +ELSE() + SET(MARIADB_CLIENT_VERSION_MAJOR "10") + SET(MARIADB_CLIENT_VERSION_MINOR "2") + SET(MARIADB_CLIENT_VERSION_PATCH "5") +ENDIF() SET(MARIADB_CLIENT_VERSION "${MARIADB_CLIENT_VERSION_MAJOR}.${MARIADB_CLIENT_VERSION_MINOR}.${MARIADB_CLIENT_VERSION_PATCH}") MATH(EXPR MARIADB_VERSION_ID "${MARIADB_CLIENT_VERSION_MAJOR} * 10000 + ${MARIADB_CLIENT_VERSION_MINOR} * 100 + diff --git a/include/ma_global.h b/include/ma_global.h index 5c48f89a..576a318b 100644 --- a/include/ma_global.h +++ b/include/ma_global.h @@ -697,7 +697,7 @@ typedef off_t os_off_t; #if defined(_WIN32) #define socket_errno WSAGetLastError() #define SOCKET_EINTR WSAEINTR -#define SOCKET_EAGAIN WSAEINPROGRESS +#define SOCKET_EAGAIN WSAEWOULDBLOCK #define SOCKET_ENFILE ENFILE #define SOCKET_EMFILE EMFILE #define SOCKET_EWOULDBLOCK WSAEWOULDBLOCK diff --git a/libmariadb/ma_default.c b/libmariadb/ma_default.c index e156b1ea..cc114c35 100644 --- a/libmariadb/ma_default.c +++ b/libmariadb/ma_default.c @@ -93,13 +93,12 @@ my_bool _mariadb_read_options(MYSQL *mysql, const char *config_file, filename= (char *)malloc(FN_REFLEN + 10); if (!_mariadb_get_default_file(filename, FN_REFLEN + 10)) { - free(filename); goto err; } } if (!(file = ma_open(filename, "r", NULL))) - return 1; + goto err; while (ma_gets(buff,sizeof(buff)-1,file)) { diff --git a/libmariadb/ma_ll2str.c b/libmariadb/ma_ll2str.c index 8f3a73b8..b96c7365 100644 --- a/libmariadb/ma_ll2str.c +++ b/libmariadb/ma_ll2str.c @@ -39,7 +39,7 @@ char *ma_ll2str(long long val,char *dst,int radix) if (radix < -36 || radix > -2) return (char*) 0; if (val < 0) { *dst++ = '-'; - val = -val; + val = 0ULL - val; } radix = -radix; } diff --git a/libmariadb/mariadb_lib.c b/libmariadb/mariadb_lib.c index 8b4c1dc4..7590001d 100644 --- a/libmariadb/mariadb_lib.c +++ b/libmariadb/mariadb_lib.c @@ -1595,6 +1595,7 @@ my_bool STDCALL mariadb_reconnect(MYSQL *mysql) tmp_mysql.options.my_cnf_group= tmp_mysql.options.my_cnf_file= NULL; if (IS_MYSQL_ASYNC_ACTIVE(mysql)) { + ctxt= mysql->options.extension->async_context; hook_data.orig_mysql= mysql; hook_data.new_mysql= &tmp_mysql; hook_data.orig_pvio= mysql->net.pvio; @@ -1632,6 +1633,8 @@ my_bool STDCALL mariadb_reconnect(MYSQL *mysql) tmp_mysql.stmts= mysql->stmts; mysql->stmts= NULL; + if (ctxt) + my_context_install_suspend_resume_hook(ctxt, NULL, NULL); /* Don't free options, we moved them to tmp_mysql */ memset(&mysql->options, 0, sizeof(mysql->options)); mysql->free_me=0; diff --git a/libmariadb/mariadb_stmt.c b/libmariadb/mariadb_stmt.c index 7fac4af4..ad82f054 100644 --- a/libmariadb/mariadb_stmt.c +++ b/libmariadb/mariadb_stmt.c @@ -313,12 +313,33 @@ static int stmt_cursor_fetch(MYSQL_STMT *stmt, uchar **row) return(MYSQL_NO_DATA); } +/* flush one result set */ void mthd_stmt_flush_unbuffered(MYSQL_STMT *stmt) { ulong packet_len; + int in_resultset= stmt->state > MYSQL_STMT_EXECUTED && + stmt->state < MYSQL_STMT_FETCH_DONE; while ((packet_len = ma_net_safe_read(stmt->mysql)) != packet_error) - if (packet_len < 8 && stmt->mysql->net.read_pos[0] == 254) - return; + { + uchar *pos= stmt->mysql->net.read_pos; + if (!in_resultset && *pos == 0) /* OK */ + { + pos++; + net_field_length(&pos); + net_field_length(&pos); + stmt->mysql->server_status= uint2korr(pos); + goto end; + } + if (packet_len < 8 && *pos == 254) /* EOF */ + { + stmt->mysql->server_status= uint2korr(pos + 3); + if (in_resultset) + goto end; + in_resultset= 1; + } + } +end: + stmt->state= MYSQL_STMT_FETCH_DONE; } int mthd_stmt_fetch_to_bind(MYSQL_STMT *stmt, unsigned char *row) @@ -1136,7 +1157,9 @@ static my_bool net_stmt_close(MYSQL_STMT *stmt, my_bool remove) /* check if all data are fetched */ if (stmt->mysql->status != MYSQL_STATUS_READY) { - stmt->mysql->methods->db_stmt_flush_unbuffered(stmt); + do { + stmt->mysql->methods->db_stmt_flush_unbuffered(stmt); + } while(mysql_stmt_more_results(stmt)); stmt->mysql->status= MYSQL_STATUS_READY; } if (stmt->state > MYSQL_STMT_INITTED) @@ -1782,7 +1805,10 @@ int STDCALL mysql_stmt_execute(MYSQL_STMT *stmt) } if (stmt->state > MYSQL_STMT_WAITING_USE_OR_STORE && stmt->state < MYSQL_STMT_FETCH_DONE && !stmt->result.data) { - mysql->methods->db_stmt_flush_unbuffered(stmt); + if (!stmt->cursor_exists) + do { + stmt->mysql->methods->db_stmt_flush_unbuffered(stmt); + } while(mysql_stmt_more_results(stmt)); stmt->state= MYSQL_STMT_PREPARED; stmt->mysql->status= MYSQL_STATUS_READY; } @@ -2238,9 +2264,11 @@ int STDCALL mariadb_stmt_execute_direct(MYSQL_STMT *stmt, /* read execute response packet */ return stmt_read_execute_response(stmt); fail: - stmt->state= MYSQL_STMT_INITTED; SET_CLIENT_STMT_ERROR(stmt, mysql->net.last_errno, mysql->net.sqlstate, mysql->net.last_error); - mysql->methods->db_stmt_flush_unbuffered(stmt); + do { + stmt->mysql->methods->db_stmt_flush_unbuffered(stmt); + } while(mysql_stmt_more_results(stmt)); + stmt->state= MYSQL_STMT_INITTED; return 1; } diff --git a/libmariadb/secure/ma_schannel.c b/libmariadb/secure/ma_schannel.c index 6c27cba6..e1e33d42 100644 --- a/libmariadb/secure/ma_schannel.c +++ b/libmariadb/secure/ma_schannel.c @@ -432,7 +432,7 @@ SECURITY_STATUS ma_schannel_handshake_loop(MARIADB_PVIO *pvio, my_bool InitialRe cbIoBuffer = 0; fDoRead = InitialRead; - /* handshake loop: We will leave a handshake is finished + /* handshake loop: We will leave if handshake is finished or an error occurs */ rc = SEC_I_CONTINUE_NEEDED; @@ -656,6 +656,7 @@ SECURITY_STATUS ma_schannel_client_handshake(MARIADB_TLS *ctls) if(BuffersOut.cbBuffer != 0 && BuffersOut.pvBuffer != NULL) { ssize_t nbytes = (DWORD)pvio->methods->write(pvio, (uchar *)BuffersOut.pvBuffer, (size_t)BuffersOut.cbBuffer); + if (nbytes <= 0) { sRet= SEC_E_INTERNAL_ERROR; diff --git a/libmariadb/secure/schannel.c b/libmariadb/secure/schannel.c index 080bfcdc..e9292bdb 100644 --- a/libmariadb/secure/schannel.c +++ b/libmariadb/secure/schannel.c @@ -26,24 +26,136 @@ extern my_bool ma_tls_initialized; +#define PROT_SSL3 1 +#define PROT_TLS1_0 2 +#define PROT_TLS1_2 4 +#define PROT_TLS1_3 8 + static struct { - ALG_ID algs[3]; /* exchange, encryption, hash */ + DWORD cipher_id; + DWORD protocol; + const char *iana_name; const char *openssl_name; + ALG_ID algs[4]; /* exchange, encryption, hash, signature */ } cipher_map[] = { - {{CALG_RSA_KEYX,CALG_AES_256,CALG_SHA}, "AES256-SHA"}, - {{CALG_RSA_KEYX,CALG_AES_128,CALG_SHA}, "AES128-SHA"}, - {{CALG_RSA_KEYX,CALG_RC4,CALG_SHA}, "RC4-SHA"}, - {{CALG_RSA_KEYX,CALG_3DES,CALG_SHA}, "DES-CBC3-SHA"}, - {{CALG_RSA_KEYX,CALG_AES_128,CALG_SHA_256 }, "AES128-SHA256" }, - {{CALG_DH_EPHEM,CALG_AES_256,CALG_SHA_384}, "DHE-RSA-AES256-GCM-SHA384"}, - {{CALG_DH_EPHEM,CALG_AES_128,CALG_SHA_256}, "DHE-RSA-AES128-GCM-SHA256"}, - {{CALG_DH_EPHEM,CALG_AES_256,CALG_SHA}, "DHE-RSA-AES256-SHA"}, - {{CALG_DH_EPHEM,CALG_AES_128,CALG_SHA}, "DHE-RSA-AES128-SHA"}, - {{CALG_RSA_KEYX,CALG_AES_128,CALG_SHA_256}, "AES128-GCM-SHA256"}, - {{CALG_RSA_KEYX,CALG_RC4,0}, "RC4-MD5"}, + { + 0x0002, + PROT_TLS1_0 | PROT_TLS1_2 | PROT_SSL3, + "TLS_RSA_WITH_NULL_SHA", "NULL-SHA", + { CALG_RSA_KEYX, 0, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x0004, + PROT_TLS1_0 | PROT_TLS1_2 | PROT_SSL3, + "TLS_RSA_WITH_RC4_128_MD5", "RC4-MD5", + { CALG_RSA_KEYX, CALG_RC4, CALG_MD5, CALG_RSA_SIGN } + }, + { + 0x0005, + PROT_TLS1_0 | PROT_TLS1_2 | PROT_SSL3, + "TLS_RSA_WITH_RC4_128_SHA", "RC4-SHA", + { CALG_RSA_KEYX, CALG_RC4, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x000A, + PROT_SSL3, + "TLS_RSA_WITH_3DES_EDE_CBC_SHA", "DES-CBC3-SHA", + {CALG_RSA_KEYX, CALG_3DES, CALG_SHA1, CALG_DSS_SIGN} + }, + { + 0x0013, + PROT_TLS1_0 | PROT_TLS1_2 | PROT_SSL3, + "TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA", "EDH-DSS-DES-CBC3-SHA", + { CALG_DH_EPHEM, CALG_3DES, CALG_SHA1, CALG_DSS_SIGN } + }, + { + 0x002F, + PROT_SSL3 | PROT_TLS1_0 | PROT_TLS1_2, + "TLS_RSA_WITH_AES_128_CBC_SHA", "AES128-SHA", + { CALG_RSA_KEYX, CALG_AES_128, CALG_SHA, CALG_RSA_SIGN} + }, + { + 0x0032, + PROT_TLS1_0 | PROT_TLS1_2, + "TLS_DHE_DSS_WITH_AES_128_CBC_SHA", "DHE-DSS-AES128-SHA", + { CALG_DH_EPHEM, CALG_AES_128, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x0033, + PROT_TLS1_0 | PROT_TLS1_2, + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA", "DHE-RSA-AES128-SHA", + { CALG_DH_EPHEM, CALG_AES_128, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x0035, + PROT_TLS1_0 | PROT_TLS1_2, + "TLS_RSA_WITH_AES_256_CBC_SHA", "AES256-SHA", + { CALG_RSA_KEYX, CALG_AES_256, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x0038, + PROT_TLS1_0 | PROT_TLS1_2, + "TLS_DHE_DSS_WITH_AES_256_CBC_SHA", "DHE-DSS-AES256-SHA", + { CALG_DH_EPHEM, CALG_AES_256, CALG_SHA1, CALG_DSS_SIGN } + }, + { + 0x0039, + PROT_TLS1_0 | PROT_TLS1_2, + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA", "DHE-RSA-AES256-SHA", + { CALG_DH_EPHEM, CALG_AES_256, CALG_SHA1, CALG_RSA_SIGN } + }, + { + 0x003B, + PROT_TLS1_2, + "TLS_RSA_WITH_NULL_SHA256", "NULL-SHA256", + { CALG_RSA_KEYX, 0, CALG_SHA_256, CALG_RSA_SIGN } + }, + { + 0x003C, + PROT_TLS1_2, + "TLS_RSA_WITH_AES_128_CBC_SHA256", "AES128-SHA256", + { CALG_RSA_KEYX, CALG_AES_128, CALG_SHA_256, CALG_RSA_SIGN } + }, + { + 0x003D, + PROT_TLS1_2, + "TLS_RSA_WITH_AES_256_CBC_SHA256", "AES256-SHA256", + { CALG_RSA_KEYX, CALG_AES_256, CALG_SHA_256, CALG_RSA_SIGN } + }, + { + 0x0040, + PROT_TLS1_2, + "TLS_DHE_DSS_WITH_AES_128_CBC_SHA256", "DHE-DSS-AES128-SHA256", + { CALG_DH_EPHEM, CALG_AES_128, CALG_SHA_256, CALG_DSS_SIGN } + }, + { + 0x009C, + PROT_TLS1_2, + "TLS_RSA_WITH_AES_128_GCM_SHA256", "AES128-GCM-SHA256", + { CALG_RSA_KEYX, CALG_AES_128, CALG_SHA_256, CALG_RSA_SIGN } + }, + { + 0x009D, + PROT_TLS1_2, + "TLS_RSA_WITH_AES_256_GCM_SHA384", "AES256-GCM-SHA384", + { CALG_RSA_KEYX, CALG_AES_256, CALG_SHA_384, CALG_RSA_SIGN } + }, + { + 0x009E, + PROT_TLS1_2, + "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", "DHE-RSA-AES128-GCM-SHA256", + { CALG_DH_EPHEM, CALG_AES_128, CALG_SHA_256, CALG_RSA_SIGN } + }, + { + 0x009F, + PROT_TLS1_2, + "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", "DHE-RSA-AES256-GCM-SHA384", + { CALG_DH_EPHEM, CALG_AES_256, CALG_SHA_384, CALG_RSA_SIGN } + } + }; #define MAX_ALG_ID 50 @@ -134,7 +246,7 @@ void *ma_tls_init(MYSQL *mysql) /* Maps between openssl suite names and schannel alg_ids. - Every suite has 3 algorithms (for exchange, encryption, hash). + Every suite has 4 algorithms (for exchange, encryption, hash and signing). The input string is a set of suite names (openssl), separated by ':' @@ -143,21 +255,32 @@ void *ma_tls_init(MYSQL *mysql) The function returns number of elements written to the 'arr'. */ +static struct _tls_version { + const char *tls_version; + DWORD protocol; +} tls_version[]= { + {"TLSv1.0", PROT_TLS1_0}, + {"TLSv1.2", PROT_TLS1_2}, + {"TLSv1.3", PROT_TLS1_3}, + {"SSLv3", PROT_SSL3} +}; -static size_t set_cipher(char * cipher_str, ALG_ID *arr , size_t arr_size) +static size_t set_cipher(char * cipher_str, DWORD protocol, ALG_ID *arr , size_t arr_size) { - char *token = strtok(cipher_str, ":"); size_t pos = 0; + while (token) { size_t i; + for(i = 0; i < sizeof(cipher_map)/sizeof(cipher_map[0]) ; i++) { - if(pos + 3 < arr_size && strcmp(cipher_map[i].openssl_name, token) == 0) + if(pos + 4 < arr_size && strcmp(cipher_map[i].openssl_name, token) == 0 || + (cipher_map[i].protocol <= protocol)) { - memcpy(arr + pos, cipher_map[i].algs, sizeof(cipher_map[i].algs)); - pos += 3; + memcpy(arr + pos, cipher_map[i].algs, sizeof(ALG_ID)* 4); + pos += 4; break; } } @@ -168,7 +291,6 @@ static size_t set_cipher(char * cipher_str, ALG_ID *arr , size_t arr_size) my_bool ma_tls_connect(MARIADB_TLS *ctls) { - my_bool blocking; MYSQL *mysql; SCHANNEL_CRED Cred; MARIADB_PVIO *pvio; @@ -177,17 +299,13 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) SECURITY_STATUS sRet; ALG_ID AlgId[MAX_ALG_ID]; WORD validTokens = 0; - + if (!ctls || !ctls->pvio) return 1;; pvio= ctls->pvio; sctx= (SC_CTX *)ctls->ssl; - /* Set socket to blocking if not already set */ - if (!(blocking= pvio->methods->is_blocking(pvio))) - pvio->methods->blocking(pvio, TRUE, 0); - mysql= pvio->mysql; if (ma_tls_set_client_certs(ctls)) @@ -198,40 +316,51 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) /* Set cipher */ if (mysql->options.ssl_cipher) { - Cred.cSupportedAlgs = (DWORD)set_cipher(mysql->options.ssl_cipher, AlgId, MAX_ALG_ID); + int i; + DWORD protocol = 0; + + /* check if a protocol was specified as a cipher: + * In this case don't allow cipher suites which belong to newer protocols + * Please note: There are no cipher suites for TLS1.1 + */ + for (i = 0; i < sizeof(tls_version) / sizeof(tls_version[0]); i++) + { + if (!stricmp(mysql->options.ssl_cipher, tls_version[i].tls_version)) + protocol |= tls_version[i].protocol; + } + memset(AlgId, 0, MAX_ALG_ID * sizeof(ALG_ID)); + Cred.cSupportedAlgs = (DWORD)set_cipher(mysql->options.ssl_cipher, protocol, AlgId, MAX_ALG_ID); if (Cred.cSupportedAlgs) { Cred.palgSupportedAlgs = AlgId; } - else + else if (!protocol) { ma_schannel_set_sec_error(pvio, SEC_E_ALGORITHM_MISMATCH); goto end; } } + Cred.dwVersion= SCHANNEL_CRED_VERSION; - if (mysql->options.extension) - { - Cred.dwMinimumCipherStrength = mysql->options.extension->tls_cipher_strength; - } - Cred.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK | SCH_SEND_ROOT_CERT | - SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION; + Cred.dwFlags = SCH_CRED_NO_SERVERNAME_CHECK | SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION; + if (sctx->client_cert_ctx) { Cred.cCreds = 1; Cred.paCred = &sctx->client_cert_ctx; } - Cred.grbitEnabledProtocols= SP_PROT_TLS1_0|SP_PROT_TLS1_1; if (mysql->options.extension && mysql->options.extension->tls_version) { if (strstr("TLSv1.0", mysql->options.extension->tls_version)) - Cred.grbitEnabledProtocols|= SP_PROT_TLS1_0; + Cred.grbitEnabledProtocols|= SP_PROT_TLS1_0_CLIENT; if (strstr("TLSv1.1", mysql->options.extension->tls_version)) - Cred.grbitEnabledProtocols|= SP_PROT_TLS1_1; + Cred.grbitEnabledProtocols|= SP_PROT_TLS1_1_CLIENT; if (strstr("TLSv1.2", mysql->options.extension->tls_version)) - Cred.grbitEnabledProtocols|= SP_PROT_TLS1_2; + Cred.grbitEnabledProtocols|= SP_PROT_TLS1_2_CLIENT; } + if (!Cred.grbitEnabledProtocols) + Cred.grbitEnabledProtocols = SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_1_CLIENT; if ((sRet= AcquireCredentialsHandleA(NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &Cred, NULL, NULL, &sctx->CredHdl, NULL)) != SEC_E_OK) @@ -243,15 +372,13 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) if (ma_schannel_client_handshake(ctls) != SEC_E_OK) goto end; - if (!ma_schannel_verify_certs(sctx)) goto end; - + return 0; end: - if (rc && sctx->IoBufferSize) LocalFree(sctx->IoBuffer); sctx->IoBufferSize= 0; @@ -378,23 +505,21 @@ end: return rc; } -static const char *cipher_name(ALG_ID keyxch, ALG_ID cipher, ALG_ID hash) +static const char *cipher_name(const SecPkgContext_CipherInfo *CipherInfo) { int i; for(i = 0; i < sizeof(cipher_map)/sizeof(cipher_map[0]) ; i++) { - if (cipher_map[i].algs[0] == keyxch && - cipher_map[i].algs[1] == cipher && - cipher_map[i].algs[2] == hash) + if (CipherInfo->dwCipherSuite == cipher_map[i].cipher_id) return cipher_map[i].openssl_name; } - return "unknown cipher"; + return ""; }; const char *ma_tls_get_cipher(MARIADB_TLS *ctls) { - SecPkgContext_ConnectionInfo cinfo; + SecPkgContext_CipherInfo CipherInfo = { SECPKGCONTEXT_CIPHERINFO_V1 }; SECURITY_STATUS sRet; SC_CTX *sctx; DWORD i= 0; @@ -404,10 +529,11 @@ const char *ma_tls_get_cipher(MARIADB_TLS *ctls) sctx= (SC_CTX *)ctls->ssl; - sRet= QueryContextAttributes(&sctx->ctxt, SECPKG_ATTR_CONNECTION_INFO, (PVOID)&cinfo); + sRet= QueryContextAttributes(&sctx->ctxt, SECPKG_ATTR_CIPHER_INFO, (PVOID)&CipherInfo); if (sRet != SEC_E_OK) return NULL; - return cipher_name(cinfo.aiExch, cinfo.aiCipher, cinfo.aiHash); + + return cipher_name(&CipherInfo); } unsigned int ma_tls_get_finger_print(MARIADB_TLS *ctls, char *fp, unsigned int len) diff --git a/plugins/pvio/pvio_socket.c b/plugins/pvio/pvio_socket.c index 251d7eb7..151e5e0a 100644 --- a/plugins/pvio/pvio_socket.c +++ b/plugins/pvio/pvio_socket.c @@ -52,9 +52,12 @@ #include #include #include +#define IS_SOCKET_EINTR(err) (err == SOCKET_EINTR) #else #include #define O_NONBLOCK 1 +#define MSG_DONTWAIT 0 +#define IS_SOCKET_EINTR(err) 0 #endif #ifndef SOCKET_ERROR @@ -63,6 +66,16 @@ #define DNS_TIMEOUT 30 +#ifndef O_NONBLOCK +#if defined(O_NDELAY) +#define O_NONBLOCK O_NODELAY +#elif defined (O_FNDELAY) +#define O_NONBLOCK O_FNDELAY +#else +#error socket blocking is not supported on this platform +#endif +#endif + /* Function prototypes */ my_bool pvio_socket_set_timeout(MARIADB_PVIO *pvio, enum enum_pvio_timeout type, int timeout); @@ -88,6 +101,8 @@ static int pvio_socket_init(char *unused1, int unused3, va_list); static int pvio_socket_end(void); +static ssize_t ma_send(my_socket socket, const uchar *buffer, size_t length, int flags); +static ssize_t ma_recv(my_socket socket, uchar *buffer, size_t length, int flags); struct st_ma_pvio_methods pvio_socket_methods= { pvio_socket_set_timeout, @@ -265,43 +280,26 @@ int pvio_socket_get_timeout(MARIADB_PVIO *pvio, enum enum_pvio_timeout type) */ ssize_t pvio_socket_read(MARIADB_PVIO *pvio, uchar *buffer, size_t length) { - ssize_t r= -1; -#ifndef _WIN32 - int read_flags= 0; -#endif - struct st_pvio_socket *csock= NULL; + ssize_t r; + int read_flags= MSG_DONTWAIT; + struct st_pvio_socket *csock; + int timeout; if (!pvio || !pvio->data) return -1; csock= (struct st_pvio_socket *)pvio->data; + timeout = pvio->timeout[PVIO_READ_TIMEOUT]; -#ifndef _WIN32 - if (pvio_socket_wait_io_or_timeout(pvio, TRUE, pvio->timeout[PVIO_READ_TIMEOUT]) < 1) - return -1; - do { - r= recv(csock->socket, (void *)buffer, length, read_flags); - } while (r == -1 && errno == EINTR); -#else + while ((r = ma_recv(csock->socket, (void *)buffer, length, read_flags)) == -1) { - WSABUF wsaData; - DWORD flags= 0, - dwBytes= 0; + int err = socket_errno; + if ((err != SOCKET_EAGAIN && err != SOCKET_EWOULDBLOCK) || timeout == 0) + return r; - /* clear error */ - errno= 0; - wsaData.len = (u_long)length; - wsaData.buf = (char*) buffer; - - r = WSARecv(csock->socket, &wsaData, 1, &dwBytes, &flags, NULL, NULL); - if (r == SOCKET_ERROR) - { - errno= WSAGetLastError(); + if (pvio_socket_wait_io_or_timeout(pvio, TRUE, timeout) < 1) return -1; - } - r= (ssize_t)dwBytes; } -#endif return r; } /* }}} */ @@ -351,22 +349,33 @@ ssize_t pvio_socket_async_read(MARIADB_PVIO *pvio, uchar *buffer, size_t length) } /* }}} */ -#ifndef _WIN32 -ssize_t ma_send(int socket, const uchar *buffer, size_t length, int flags) +static ssize_t ma_send(my_socket socket, const uchar *buffer, size_t length, int flags) { ssize_t r; -#if !defined(MSG_NOSIGNAL) && !defined(SO_NOSIGPIPE) +#if !defined(MSG_NOSIGNAL) && !defined(SO_NOSIGPIPE) && !defined(_WIN32) struct sigaction act, oldact; act.sa_handler= SIG_IGN; sigaction(SIGPIPE, &act, &oldact); #endif - r= send(socket, buffer, length, flags); -#if !defined(MSG_NOSIGNAL) && !defined(SO_NOSIGPIPE) + do { + r = send(socket, buffer, IF_WIN((int)length,length), flags); + } + while (r == -1 && IS_SOCKET_EINTR(socket_errno)); +#if !defined(MSG_NOSIGNAL) && !defined(SO_NOSIGPIPE) && !defined(_WIN32) sigaction(SIGPIPE, &oldact, NULL); #endif return r; } -#endif + +static ssize_t ma_recv(my_socket socket, uchar *buffer, size_t length, int flags) +{ + ssize_t r; + do { + r = recv(socket, buffer, IF_WIN((int)length, length), flags); + } + while (r == -1 && IS_SOCKET_EINTR(socket_errno)); + return r; +} /* {{{ pvio_socket_async_write */ /* @@ -439,49 +448,27 @@ ssize_t pvio_socket_async_write(MARIADB_PVIO *pvio, const uchar *buffer, size_t */ ssize_t pvio_socket_write(MARIADB_PVIO *pvio, const uchar *buffer, size_t length) { - ssize_t r= -1; - struct st_pvio_socket *csock= NULL; -#ifndef _WIN32 + ssize_t r; + struct st_pvio_socket *csock; + int timeout; int send_flags= MSG_DONTWAIT; #ifdef MSG_NOSIGNAL send_flags|= MSG_NOSIGNAL; -#endif #endif if (!pvio || !pvio->data) return -1; csock= (struct st_pvio_socket *)pvio->data; + timeout = pvio->timeout[PVIO_WRITE_TIMEOUT]; -#ifndef _WIN32 - do { - r= ma_send(csock->socket, buffer, length, send_flags); - } while (r == -1 && errno == EINTR); - - while (r == -1 && (errno == EAGAIN || errno == EWOULDBLOCK) && - pvio->timeout[PVIO_WRITE_TIMEOUT] != 0) + while ((r = ma_send(csock->socket, (void *)buffer, length,send_flags)) == -1) { - if (pvio_socket_wait_io_or_timeout(pvio, FALSE, pvio->timeout[PVIO_WRITE_TIMEOUT]) < 1) + int err = socket_errno; + if ((err != SOCKET_EAGAIN && err != SOCKET_EWOULDBLOCK)|| timeout == 0) + return r; + if (pvio_socket_wait_io_or_timeout(pvio, FALSE, timeout) < 1) return -1; - do { - r= ma_send(csock->socket, buffer, length, send_flags); - } while (r == -1 && errno == EINTR); } -#else - { - WSABUF wsaData; - DWORD dwBytes= 0; - - wsaData.len = (u_long)length; - wsaData.buf = (char*) buffer; - - r = WSASend(csock->socket, &wsaData, 1, &dwBytes, 0, NULL, NULL); - if (r == SOCKET_ERROR) { - errno= WSAGetLastError(); - return -1; - } - r= dwBytes; - } -#endif return r; } /* }}} */ @@ -539,6 +526,7 @@ int pvio_socket_wait_io_or_timeout(MARIADB_PVIO *pvio, my_bool is_read, int time else if (rc == 0) { rc= SOCKET_ERROR; + WSASetLastError(WSAETIMEDOUT); errno= ETIMEDOUT; } else if (FD_ISSET(csock->socket, &exc_fds)) @@ -547,6 +535,7 @@ int pvio_socket_wait_io_or_timeout(MARIADB_PVIO *pvio, my_bool is_read, int time int len = sizeof(int); if (getsockopt(csock->socket, SOL_SOCKET, SO_ERROR, (char *)&err, &len) != SOCKET_ERROR) { + WSASetLastError(err); errno= err; } rc= SOCKET_ERROR; @@ -559,50 +548,42 @@ int pvio_socket_wait_io_or_timeout(MARIADB_PVIO *pvio, my_bool is_read, int time my_bool pvio_socket_blocking(MARIADB_PVIO *pvio, my_bool block, my_bool *previous_mode) { - int *sd_flags, save_flags; - my_bool tmp; - struct st_pvio_socket *csock= NULL; + my_bool is_blocking; + struct st_pvio_socket *csock; + int new_fcntl_mode; if (!pvio || !pvio->data) return 1; - csock= (struct st_pvio_socket *)pvio->data; - sd_flags= &csock->fcntl_mode; - save_flags= csock->fcntl_mode; + csock = (struct st_pvio_socket *)pvio->data; - if (!previous_mode) - previous_mode= &tmp; + is_blocking = !(csock->fcntl_mode & O_NONBLOCK); + if (previous_mode) + *previous_mode = is_blocking; + + if (is_blocking == block) + return 0; + + if (block) + new_fcntl_mode = csock->fcntl_mode & ~O_NONBLOCK; + else + new_fcntl_mode = csock->fcntl_mode | O_NONBLOCK; #ifdef _WIN32 - *previous_mode= (*sd_flags & O_NONBLOCK) != 0; - *sd_flags = (block) ? *sd_flags & ~O_NONBLOCK : *sd_flags | O_NONBLOCK; { - ulong arg= 1 - block; + ulong arg = block ? 0 : 1; if (ioctlsocket(csock->socket, FIONBIO, (void *)&arg)) { - csock->fcntl_mode= save_flags; return(WSAGetLastError()); } } #else -#if defined(O_NONBLOCK) - *previous_mode= (*sd_flags & O_NONBLOCK) != 0; - *sd_flags = (block) ? *sd_flags & ~O_NONBLOCK : *sd_flags | O_NONBLOCK; -#elif defined(O_NDELAY) - *previous_mode= (*sd_flags & O_NODELAY) != 0; - *sd_flags = (block) ? *sd_flags & ~O_NODELAY : *sd_flags | O_NODELAY; -#elif defined(FNDELAY) - *previous_mode= (*sd_flags & O_FNDELAY) != 0; - *sd_flags = (block) ? *sd_flags & ~O_FNDELAY : *sd_flags | O_FNDELAY; -#else -#error socket blocking is not supported on this platform -#endif - if (fcntl(csock->socket, F_SETFL, *sd_flags) == -1) + if (fcntl(csock->socket, F_SETFL, new_fcntl_mode) == -1) { - csock->fcntl_mode= save_flags; return errno; } #endif + csock->fcntl_mode = new_fcntl_mode; return 0; } diff --git a/unittest/libmariadb/CMakeLists.txt b/unittest/libmariadb/CMakeLists.txt index d9718672..82845bc9 100644 --- a/unittest/libmariadb/CMakeLists.txt +++ b/unittest/libmariadb/CMakeLists.txt @@ -26,7 +26,7 @@ INCLUDE_DIRECTORIES(${CC_SOURCE_DIR}/include ${CC_SOURCE_DIR}/unittest/libmariadb) ADD_DEFINITIONS(-DLIBMARIADB) -SET(API_TESTS "performance" "basic-t" "fetch" "charset" "logs" "cursor" "errors" "view" "ps" "ps_bugs" "sp" "result" "connection" "misc" "ps_new" "sqlite3" "thread" "features-10_2" "bulk1" ) +SET(API_TESTS "bulk1" "performance" "basic-t" "fetch" "charset" "logs" "cursor" "errors" "view" "ps" "ps_bugs" "sp" "result" "connection" "misc" "ps_new" "sqlite3" "thread" "features-10_2" "bulk1") IF(WITH_DYNCOL) SET(API_TESTS ${API_TESTS} "dyncol") ENDIF() diff --git a/unittest/libmariadb/bulk1.c b/unittest/libmariadb/bulk1.c index efaa1850..895f4da8 100644 --- a/unittest/libmariadb/bulk1.c +++ b/unittest/libmariadb/bulk1.c @@ -402,6 +402,7 @@ static int bulk5(MYSQL *mysql) res= mysql_store_result(mysql); rows= (unsigned long)mysql_num_rows(res); + diag("rows: %lu", rows); mysql_free_result(res); FAIL_IF(rows != 5, "expected 5 rows"); diff --git a/unittest/libmariadb/connection.c b/unittest/libmariadb/connection.c index 5b524893..2295f90b 100644 --- a/unittest/libmariadb/connection.c +++ b/unittest/libmariadb/connection.c @@ -988,7 +988,7 @@ static int test_sess_track_db(MYSQL *mysql) return OK; } - +#ifndef WIN32 static int test_unix_socket_close(MYSQL *unused __attribute__((unused))) { MYSQL *mysql= mysql_init(NULL); @@ -1016,6 +1016,7 @@ static int test_unix_socket_close(MYSQL *unused __attribute__((unused))) mysql_close(mysql); return OK; } +#endif static int test_reset(MYSQL *mysql) { @@ -1069,9 +1070,32 @@ static int test_reset(MYSQL *mysql) return OK; } +static int test_mdev12446(MYSQL *my __attribute__((unused))) +{ + /* + if specified file didn't exist, valgrind reported a leak, + if no file was specified and no default file is installed, + C/C crashed due to double free. + */ + MYSQL *mysql= mysql_init(NULL); + mysql_options(mysql, MYSQL_READ_DEFAULT_FILE, "file.notfound"); + FAIL_IF(!my_test_connect(mysql, hostname, username, password, schema, + port, socketname, 0), mysql_error(mysql)); + mysql_close(mysql); + mysql= mysql_init(NULL); + mysql_options(mysql, MYSQL_READ_DEFAULT_GROUP, "notfound"); + FAIL_IF(!my_test_connect(mysql, hostname, username, password, schema, + port, socketname, 0), mysql_error(mysql)); + mysql_close(mysql); + return OK; +} + struct my_tests_st my_tests[] = { + {"test_mdev12446", test_mdev12446, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_reset", test_reset, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, +#ifndef WIN32 {"test_unix_socket_close", test_unix_socket_close, TEST_CONNECTION_NONE, 0, NULL, NULL}, +#endif {"test_sess_track_db", test_sess_track_db, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_get_options", test_get_options, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_wrong_bind_address", test_wrong_bind_address, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, diff --git a/unittest/libmariadb/features-10_2.c b/unittest/libmariadb/features-10_2.c index 12bf46f5..d46e6acc 100644 --- a/unittest/libmariadb/features-10_2.c +++ b/unittest/libmariadb/features-10_2.c @@ -190,8 +190,31 @@ static int conc_218(MYSQL *mysql) return OK; } +static int test_cursor(MYSQL *mysql) +{ + int rc; + MYSQL_STMT *stmt; + unsigned int prefetch_rows= 1; + unsigned long cursor_type= CURSOR_TYPE_READ_ONLY; + + stmt= mysql_stmt_init(mysql); + rc= mysql_stmt_attr_set(stmt, STMT_ATTR_CURSOR_TYPE, &cursor_type); + check_stmt_rc(rc, stmt); + rc= mysql_stmt_attr_set(stmt, STMT_ATTR_PREFETCH_ROWS, &prefetch_rows); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "SELECT 1 FROM DUAL UNION SELECT 2 FROM DUAL", -1); + check_stmt_rc(rc, stmt); + rc= mysql_stmt_fetch(stmt); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "SELECT 1 FROM DUAL UNION SELECT 2 FROM DUAL", -1); + check_stmt_rc(rc, stmt); + mysql_stmt_close(stmt); + return OK; +} + struct my_tests_st my_tests[] = { + {"test_cursor", test_cursor, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"conc_218", conc_218, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"conc_212", conc_212, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"conc_213", conc_213, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, diff --git a/unittest/libmariadb/misc.c b/unittest/libmariadb/misc.c index 28d33c0d..0b11917e 100644 --- a/unittest/libmariadb/misc.c +++ b/unittest/libmariadb/misc.c @@ -265,7 +265,7 @@ static int test_frm_bug(MYSQL *mysql) sprintf(test_frm, "%s/%s/test_frm_bug.frm", data_dir, schema); - if (!(test_file= fopen(test_frm, "rw"))) + if (!(test_file= fopen(test_frm, "w"))) { mysql_stmt_close(stmt); diag("Can't write to file %s -> SKIP", test_frm); diff --git a/unittest/libmariadb/ps.c b/unittest/libmariadb/ps.c index 169f722a..06bbe7c5 100644 --- a/unittest/libmariadb/ps.c +++ b/unittest/libmariadb/ps.c @@ -4974,7 +4974,106 @@ static int test_bit2tiny(MYSQL *mysql) return OK; } +static int test_reexecute(MYSQL *mysql) +{ + MYSQL_STMT *stmt; + MYSQL_BIND ps_params[3]; /* input parameter buffers */ + int int_data[3]; /* input/output values */ + int rc; + + /* set up stored procedure */ + rc = mysql_query(mysql, "DROP PROCEDURE IF EXISTS p1"); + check_mysql_rc(rc, mysql); + + rc = mysql_query(mysql, + "CREATE PROCEDURE p1(" + " IN p_in INT, " + " OUT p_out INT, " + " INOUT p_inout INT) " + "BEGIN " + " SELECT p_in, p_out, p_inout; " + " SET p_in = 100, p_out = 200, p_inout = 300; " + " SELECT p_in, p_out, p_inout; " + "END"); + check_mysql_rc(rc, mysql); + + /* initialize and prepare CALL statement with parameter placeholders */ + stmt = mysql_stmt_init(mysql); + if (!stmt) + { + diag("Could not initialize statement"); + exit(1); + } + rc = mysql_stmt_prepare(stmt, "CALL p1(?, ?, ?)", 16); + check_stmt_rc(rc, stmt); + + /* initialize parameters: p_in, p_out, p_inout (all INT) */ + memset(ps_params, 0, sizeof (ps_params)); + + ps_params[0].buffer_type = MYSQL_TYPE_LONG; + ps_params[0].buffer = (char *) &int_data[0]; + ps_params[0].length = 0; + ps_params[0].is_null = 0; + + ps_params[1].buffer_type = MYSQL_TYPE_LONG; + ps_params[1].buffer = (char *) &int_data[1]; + ps_params[1].length = 0; + ps_params[1].is_null = 0; + + ps_params[2].buffer_type = MYSQL_TYPE_LONG; + ps_params[2].buffer = (char *) &int_data[2]; + ps_params[2].length = 0; + ps_params[2].is_null = 0; + + /* bind parameters */ + rc = mysql_stmt_bind_param(stmt, ps_params); + check_stmt_rc(rc, stmt); + + /* assign values to parameters and execute statement */ + int_data[0]= 10; /* p_in */ + int_data[1]= 20; /* p_out */ + int_data[2]= 30; /* p_inout */ + + rc = mysql_stmt_execute(stmt); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_execute(stmt); + check_stmt_rc(rc, stmt); + + mysql_stmt_close(stmt); + + rc = mysql_query(mysql, "DROP PROCEDURE IF EXISTS p1"); + check_mysql_rc(rc, mysql); + return OK; +} + +static int test_prepare_error(MYSQL *mysql) +{ + MYSQL_STMT *stmt= mysql_stmt_init(mysql); + int rc; + + rc= mysql_stmt_prepare(stmt, "SELECT 1 FROM tbl_not_exists", -1); + FAIL_IF(!rc, "Expected error"); + + rc= mysql_stmt_reset(stmt); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_prepare(stmt, "SELECT 1 FROM tbl_not_exists", -1); + FAIL_IF(!rc, "Expected error"); + + rc= mysql_stmt_reset(stmt); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_prepare(stmt, "SET @a:=1", -1); + check_stmt_rc(rc, stmt); + + mysql_stmt_close(stmt); + return OK; +} + struct my_tests_st my_tests[] = { + {"test_prepare_error", test_prepare_error, TEST_CONNECTION_NEW, 0, NULL, NULL}, + {"test_reexecute", test_reexecute, TEST_CONNECTION_NEW, 0, NULL, NULL}, {"test_bit2tiny", test_bit2tiny, TEST_CONNECTION_NEW, 0, NULL, NULL}, {"test_conc97", test_conc97, TEST_CONNECTION_NEW, 0, NULL, NULL}, {"test_conc83", test_conc83, TEST_CONNECTION_NONE, 0, NULL, NULL},