diff --git a/include/mariadb_stmt.h b/include/mariadb_stmt.h index b8ee35ad..97b6c298 100644 --- a/include/mariadb_stmt.h +++ b/include/mariadb_stmt.h @@ -56,7 +56,8 @@ enum enum_stmt_attr_type { STMT_ATTR_UPDATE_MAX_LENGTH, STMT_ATTR_CURSOR_TYPE, - STMT_ATTR_PREFETCH_ROWS + STMT_ATTR_PREFETCH_ROWS, + STMT_ATTR_PREBIND_PARAMS=200 }; enum enum_cursor_type @@ -172,7 +173,7 @@ typedef int (*mysql_stmt_fetch_row_func)(MYSQL_STMT *stmt, unsigned char **row) struct st_mysql_stmt { - MA_MEM_ROOT mem_root; + MA_MEM_ROOT mem_root; MYSQL *mysql; unsigned long stmt_id; unsigned long flags;/* cursor is set here */ diff --git a/libmariadb/mariadb_lib.c b/libmariadb/mariadb_lib.c index ca232923..3aeb0c0b 100644 --- a/libmariadb/mariadb_lib.c +++ b/libmariadb/mariadb_lib.c @@ -2567,9 +2567,12 @@ int mariadb_flush_multi_command(MYSQL *mysql) /* reset multi_buff */ mysql->net.extension->mbuff_pos= mysql->net.extension->mbuff; + /* don't read result for mysql_stmt_execute_direct() */ if (!rc) if (mysql->net.extension->mbuff && length > 3 && - (mysql->net.extension->mbuff[3] == COM_STMT_PREPARE || mysql->net.extension->mbuff[3] == COM_STMT_EXECUTE)) + (mysql->net.extension->mbuff[3] == COM_STMT_PREPARE || + mysql->net.extension->mbuff[3] == COM_STMT_EXECUTE || + mysql->net.extension->mbuff[3] == COM_STMT_CLOSE)) return rc; else return mysql->methods->db_read_query_result(mysql); diff --git a/libmariadb/mariadb_stmt.c b/libmariadb/mariadb_stmt.c index e40ca3b1..56afa1d9 100644 --- a/libmariadb/mariadb_stmt.c +++ b/libmariadb/mariadb_stmt.c @@ -721,6 +721,8 @@ my_bool STDCALL mysql_stmt_attr_get(MYSQL_STMT *stmt, enum enum_stmt_attr_type a case STMT_ATTR_PREFETCH_ROWS: *(unsigned long *)value= stmt->prefetch_rows; break; + case STMT_ATTR_PREBIND_PARAMS: + *(unsigned int *)value= stmt->param_count; default: return(1); } @@ -747,6 +749,9 @@ my_bool STDCALL mysql_stmt_attr_set(MYSQL_STMT *stmt, enum enum_stmt_attr_type a else stmt->prefetch_rows= *(long *)value; break; + case STMT_ATTR_PREBIND_PARAMS: + stmt->param_count= *(unsigned int *)value; + break; default: SET_CLIENT_STMT_ERROR(stmt, CR_NOT_IMPLEMENTED, SQLSTATE_UNKNOWN, 0); return(1); @@ -764,26 +769,19 @@ my_bool STDCALL mysql_stmt_bind_param(MYSQL_STMT *stmt, MYSQL_BIND *bind) return(1); } - /* for mariadb_stmt_execute_direct we need to bind parameters in advance: - client has to pass a bind array, where last parameter needs to be set - to buffer type MAX_NO_FIELD_TYPES */ + /* if we want to call mariadb_stmt_execute_direct the number of parameters + is unknown, since we didn't prepare the statement at this point. + Number of parameters needs to be set manually via mysql_stmt_attr_set() + function */ if (stmt->state < MYSQL_STMT_PREPARED && !(mysql->server_capabilities & CLIENT_MYSQL)) { - if (!stmt->params) + if (!stmt->params && stmt->param_count) { - int param_count; - for(param_count= 0; - bind[param_count].buffer_type != MAX_NO_FIELD_TYPES; - param_count++); - stmt->param_count= param_count; - if (stmt->param_count) + if (!(stmt->params= (MYSQL_BIND *)ma_alloc_root(&stmt->mem_root, stmt->param_count * sizeof(MYSQL_BIND)))) { - if (!(stmt->params= (MYSQL_BIND *)ma_alloc_root(&stmt->mem_root, stmt->param_count * sizeof(MYSQL_BIND)))) - { - SET_CLIENT_STMT_ERROR(stmt, CR_OUT_OF_MEMORY, SQLSTATE_UNKNOWN, 0); - return(1); - } + SET_CLIENT_STMT_ERROR(stmt, CR_OUT_OF_MEMORY, SQLSTATE_UNKNOWN, 0); + return(1); } memset(stmt->params, '\0', stmt->param_count * sizeof(MYSQL_BIND)); } @@ -1963,17 +1961,53 @@ int STDCALL mariadb_stmt_execute_direct(MYSQL_STMT *stmt, if (mysql_optionsv(mysql, MARIADB_OPT_COM_MULTI, &multi)) goto fail; + if (!stmt->mysql) + { + SET_CLIENT_STMT_ERROR(stmt, CR_SERVER_LOST, SQLSTATE_UNKNOWN, 0); + return(1); + } + if (length == -1) length= strlen(stmt_str); - if (mysql_stmt_prepare(stmt, stmt_str, length)) + mysql_get_optionv(mysql, MARIADB_OPT_COM_MULTI, &multi); + + /* clear flags */ + CLEAR_CLIENT_STMT_ERROR(stmt); + CLEAR_CLIENT_ERROR(stmt->mysql); + stmt->upsert_status.affected_rows= mysql->affected_rows= (unsigned long long) ~0; + + /* check if we have to clear results */ + if (stmt->state > MYSQL_STMT_INITTED) + { + /* We need to semi-close the prepared statement: + reset stmt and free all buffers and close the statement + on server side. Statment handle will get a new stmt_id */ + char stmt_id[STMT_ID_LENGTH]; + + if (mysql_stmt_internal_reset(stmt, 1)) + goto fail; + + ma_free_root(&stmt->mem_root, MYF(MY_KEEP_PREALLOC)); + ma_free_root(&((MADB_STMT_EXTENSION *)stmt->extension)->fields_ma_alloc_root, MYF(0)); + stmt->field_count= 0; + + int4store(stmt_id, stmt->stmt_id); + if (mysql->methods->db_command(mysql, COM_STMT_CLOSE, stmt_id, + sizeof(stmt_id), 1, stmt)) + goto fail; + } + if (mysql->methods->db_command(mysql, COM_STMT_PREPARE, stmt_str, length, 1, stmt)) goto fail; stmt->state= MYSQL_STMT_PREPARED; - + /* Since we can't determine stmt_id here, we need to set it to -1, so server will know that the + * execute command belongs to previous prepare */ + stmt->stmt_id= -1; if (mysql_stmt_execute(stmt)) goto fail; + /* flush multi buffer */ multi= MARIADB_COM_MULTI_END; if (mysql_optionsv(mysql, MARIADB_OPT_COM_MULTI, &multi)) goto fail; diff --git a/libmariadb/secure/openssl.c b/libmariadb/secure/openssl.c index 8aea5239..2918340e 100644 --- a/libmariadb/secure/openssl.c +++ b/libmariadb/secure/openssl.c @@ -493,8 +493,8 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls) pvio->methods->blocking(pvio, TRUE, 0); SSL_clear(ssl); - SSL_SESSION_set_timeout(SSL_get_session(ssl), - mysql->options.connect_timeout); + /*SSL_SESSION_set_timeout(SSL_get_session(ssl), + mysql->options.connect_timeout); */ SSL_set_fd(ssl, mysql_get_socket(mysql)); if (SSL_connect(ssl) != 1) diff --git a/unittest/libmariadb/CMakeLists.txt b/unittest/libmariadb/CMakeLists.txt index c5da8b41..bd528243 100644 --- a/unittest/libmariadb/CMakeLists.txt +++ b/unittest/libmariadb/CMakeLists.txt @@ -70,6 +70,8 @@ ENDIF() ADD_LIBRARY(ma_getopt ma_getopt.c) +ADD_EXECUTABLE(my_test test.c) +TARGET_LINK_LIBRARIES(my_test mariadbclient) FOREACH(API_TEST ${API_TESTS}) ADD_EXECUTABLE(${API_TEST} ${API_TEST}.c) TARGET_LINK_LIBRARIES(${API_TEST} mytap ma_getopt mariadbclient) diff --git a/unittest/libmariadb/connection.c b/unittest/libmariadb/connection.c index 33f0f691..dbbb7d0c 100644 --- a/unittest/libmariadb/connection.c +++ b/unittest/libmariadb/connection.c @@ -876,7 +876,8 @@ static int test_get_options(MYSQL *my) mysql_options(mysql, options_char[i], char1); char2= NULL; mysql_get_optionv(mysql, options_char[i], (void *)&char2); - FAIL_IF(strcmp(char1, char2), "mysql_get_optionv (char) failed"); + if (options_char[i] != MYSQL_SET_CHARSET_NAME) + FAIL_IF(strcmp(char1, char2), "mysql_get_optionv (char) failed"); } for (i=0; i < 3; i++) diff --git a/unittest/libmariadb/features-10_2.c b/unittest/libmariadb/features-10_2.c index 2ae227dc..dc2639b6 100644 --- a/unittest/libmariadb/features-10_2.c +++ b/unittest/libmariadb/features-10_2.c @@ -72,7 +72,7 @@ static int com_multi_1(MYSQL *mysql) #define repeat1 100 -#define repeat2 10 +#define repeat2 1 static int com_multi_2(MYSQL *mysql) { @@ -171,11 +171,11 @@ static int com_multi_ps1(MYSQL *mysql) static int com_multi_ps2(MYSQL *mysql) { MYSQL_STMT *stmt; - MYSQL_BIND bind[3]; + MYSQL_BIND bind[2]; int intval= 3, rc; int i; char *varval= "com_multi_ps2"; - + unsigned int param_count= 2; if (!have_com_multi) return SKIP; @@ -184,28 +184,79 @@ static int com_multi_ps2(MYSQL *mysql) check_mysql_rc(rc, mysql); rc= mysql_query(mysql, "CREATE TABLE t1 (a int, b varchar(20))"); - memset(&bind, 0, sizeof(MYSQL_BIND) * 3); + memset(&bind, 0, sizeof(MYSQL_BIND) * 2); bind[0].buffer_type= MYSQL_TYPE_SHORT; bind[0].buffer= &intval; bind[1].buffer_type= MYSQL_TYPE_STRING; bind[1].buffer= varval; bind[1].buffer_length= strlen(varval); - bind[2].buffer_type= MAX_NO_FIELD_TYPES; + + stmt= mysql_stmt_init(mysql); + mysql_stmt_attr_set(stmt, STMT_ATTR_PREBIND_PARAMS, ¶m_count); + rc= mysql_stmt_bind_param(stmt, bind); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "INSERT INTO t1 VALUES (?,?)", -1); + check_stmt_rc(rc, stmt); for (i=0; i < 2; i++) { - stmt= mysql_stmt_init(mysql); - rc= mysql_stmt_bind_param(stmt, bind); - check_stmt_rc(rc, stmt); - - rc= mariadb_stmt_execute_direct(stmt, "INSERT INTO t1 VALUES (1,'foo')", -1); + mysql_stmt_execute(stmt); check_stmt_rc(rc, stmt); FAIL_IF(mysql_stmt_affected_rows(stmt) != 1, "expected affected_rows= 1"); FAIL_IF(stmt->stmt_id < 1, "expected statement id > 0"); - - rc= mysql_stmt_close(stmt); - check_mysql_rc(rc, mysql); } + rc= mysql_stmt_close(stmt); + check_mysql_rc(rc, mysql); + + return OK; +} + +static int execute_direct(MYSQL *mysql) +{ + long rc= 0, i= 0; + MYSQL_STMT *stmt; + MYSQL_BIND bind; + unsigned int param_count= 1; + MYSQL_RES *res= NULL; + + stmt= mysql_stmt_init(mysql); + + rc= mariadb_stmt_execute_direct(stmt, "DROP TABLE IF EXISTS t1", -1); + check_stmt_rc(rc, stmt); + + rc= mariadb_stmt_execute_direct(stmt, "CREATE TABLE t1 (a int)", -1); + check_stmt_rc(rc, stmt); + + memset(&bind, 0, sizeof(MYSQL_BIND)); + + bind.buffer= &i; + bind.buffer_type= MYSQL_TYPE_LONG; + bind.buffer_length= sizeof(long); + + mysql_stmt_close(stmt); + stmt= mysql_stmt_init(mysql); + mysql_stmt_attr_set(stmt, STMT_ATTR_PREBIND_PARAMS, ¶m_count); + + rc= mysql_stmt_bind_param(stmt, &bind); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "INSERT INTO t1 VALUES (?)", -1); + check_stmt_rc(rc, stmt); + + for (i=1; i < 1000; i++) + { + rc= mysql_stmt_execute(stmt); + check_stmt_rc(rc, stmt); + } + rc= mysql_stmt_close(stmt); + check_mysql_rc(rc, mysql); + + rc= mysql_query(mysql, "SELECT * FROM t1"); + check_mysql_rc(rc, mysql); + + res= mysql_store_result(mysql); + FAIL_IF(mysql_num_rows(res) != 1000, "Expected 1000 rows"); + + mysql_free_result(res); return OK; } @@ -215,6 +266,7 @@ struct my_tests_st my_tests[] = { {"com_multi_2", com_multi_2, TEST_CONNECTION_NEW, 0, NULL, NULL}, {"com_multi_ps1", com_multi_ps1, TEST_CONNECTION_NEW, 0, NULL, NULL}, {"com_multi_ps2", com_multi_ps2, TEST_CONNECTION_NEW, 0, NULL, NULL}, + {"execute_direct", execute_direct, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {NULL, NULL, 0, 0, NULL, NULL} }; diff --git a/unittest/libmariadb/ps.c b/unittest/libmariadb/ps.c index eb67b5ec..0be3033e 100644 --- a/unittest/libmariadb/ps.c +++ b/unittest/libmariadb/ps.c @@ -74,11 +74,11 @@ static int test_conc83(MYSQL *my) check_stmt_rc(rc, stmt); diag("Ok"); - /* 2. Status is prepared, second prepare should fail */ + /* 2. Status is prepared, execute should fail */ rc= mysql_kill(mysql, mysql_thread_id(mysql)); sleep(2); - rc= mysql_stmt_prepare(stmt, query, -1); + rc= mysql_stmt_execute(stmt); FAIL_IF(!rc, "Error expected"); mysql_stmt_close(stmt); diff --git a/unittest/libmariadb/ps_bugs.c b/unittest/libmariadb/ps_bugs.c index 40a4307b..7a896fac 100644 --- a/unittest/libmariadb/ps_bugs.c +++ b/unittest/libmariadb/ps_bugs.c @@ -4247,7 +4247,94 @@ static int test_conc179(MYSQL *mysql) return OK; } +static int test_conc182(MYSQL *mysql) +{ + MYSQL_STMT *stmt; + int rc; + MYSQL_BIND bind[2]; + char buf1[22]; + MYSQL_RES *result; + MYSQL_ROW row; + + stmt= mysql_stmt_init(mysql); + rc= mariadb_stmt_execute_direct(stmt, "DROP TABLE IF EXISTS t1", -1); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "DROP TABLE IF EXISTS t1", -1); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "SELECT 1", -1); + check_stmt_rc(rc, stmt); + rc= mariadb_stmt_execute_direct(stmt, "SELECT 1", -1); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_close(stmt); + check_mysql_rc(rc, mysql); + + rc= mysql_query(mysql, "SELECT row_count()"); + result= mysql_store_result(mysql); + row= mysql_fetch_row(result); + diag("buf: %s", row[0]); + mysql_free_result(result); + + + stmt= mysql_stmt_init(mysql); + rc= mysql_stmt_prepare(stmt, "SELECT row_count()", -1); + check_stmt_rc(rc, stmt); + rc= mysql_stmt_execute(stmt); + + memset(bind, 0, 2 * sizeof(MYSQL_BIND)); + bind[0].buffer= &buf1; + bind[0].buffer_length= bind[1].buffer_length= 20; + bind[0].buffer_type= bind[1].buffer_type= MYSQL_TYPE_STRING; + + rc= mysql_stmt_bind_result(stmt, bind); + + while(!mysql_stmt_fetch(stmt)) + diag("b1: %s", buf1); + rc= mysql_stmt_close(stmt); + return OK; +} + +static int test_conc181(MYSQL *mysql) +{ + MYSQL_STMT *stmt; + int rc; + MYSQL_BIND bind; + char *stmt_str= "SELECT a FROM t1"; + float f=1; + my_bool err= 0; + + rc= mysql_query(mysql, "DROP TABLE IF EXISTS t1"); + check_mysql_rc(rc, mysql); + rc= mysql_query(mysql, "CREATE TABLE t1 (a int)"); + check_mysql_rc(rc, mysql); + rc= mysql_query(mysql, "INSERT INTO t1 VALUES(1073741825)"); + check_mysql_rc(rc, mysql); + + stmt= mysql_stmt_init(mysql); + rc= mysql_stmt_prepare(stmt, stmt_str, strlen(stmt_str)); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_execute(stmt); + check_stmt_rc(rc, stmt); + + memset(&bind, 0, sizeof(MYSQL_BIND)); + bind.buffer= &f; + bind.error= &err; + bind.buffer_type= MYSQL_TYPE_FLOAT; + rc= mysql_stmt_bind_result(stmt, &bind); + check_stmt_rc(rc, stmt); + + rc= mysql_stmt_fetch(stmt); + diag("rc=%d err=%d float=%f, %d", rc, err, f, MYSQL_DATA_TRUNCATED); + + rc= mysql_stmt_close(stmt); + return OK; +} + + struct my_tests_st my_tests[] = { + {"test_conc182", test_conc182, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, + {"test_conc181", test_conc181, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_conc179", test_conc179, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_conc177", test_conc177, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"test_conc167", test_conc167, TEST_CONNECTION_DEFAULT, 0, NULL, NULL},