diff --git a/libmariadb/mariadb_stmt.c b/libmariadb/mariadb_stmt.c index 79c23353..2018cffe 100644 --- a/libmariadb/mariadb_stmt.c +++ b/libmariadb/mariadb_stmt.c @@ -421,6 +421,25 @@ unsigned char *mysql_net_store_length(unsigned char *packet, size_t length) return packet + 8; } +static long ma_get_length(MYSQL_STMT *stmt, unsigned int param_nr, unsigned long row_nr) +{ + if (!stmt->params[param_nr].length) + return 0; + if (stmt->row_size) + return *(long *)((char *)stmt->params[param_nr].length + row_nr * stmt->row_size); + else + return stmt->params[param_nr].length[row_nr]; +} + +static char ma_get_indicator(MYSQL_STMT *stmt, unsigned int param_nr, unsigned long row_nr) +{ + if (!stmt->params[param_nr].u.indicator) + return 0; + if (stmt->row_size) + return *((char *)stmt->params[param_nr].u.indicator + (row_nr * stmt->row_size)); + return stmt->params[param_nr].u.indicator[row_nr]; +} + static void *ma_get_buffer_offset(MYSQL_STMT *stmt, enum enum_field_types type, void *buffer, unsigned long row_nr) { @@ -441,6 +460,8 @@ int store_param(MYSQL_STMT *stmt, int column, unsigned char **p, unsigned long r { void *buf= ma_get_buffer_offset(stmt, stmt->params[column].buffer_type, stmt->params[column].buffer, row_nr); + char indicator= ma_get_indicator(stmt, column, row_nr); + switch (stmt->params[column].buffer_type) { case MYSQL_TYPE_TINY: int1store(*p, (*(uchar *)buf)); @@ -558,16 +579,15 @@ int store_param(MYSQL_STMT *stmt, int column, unsigned char **p, unsigned long r ulong len; /* to is after p. The latter hasn't been moved */ uchar *to; - void *buf= ma_get_buffer_offset(stmt, stmt->params[column].buffer_type, - stmt->params[column].buffer, row_nr); - if (stmt->row_size) - len= *stmt->params[column].length; + if (indicator == STMT_INDICATOR_NTS) + len= -1; else - len= (ulong)STMT_NUM_OFS(long, stmt->params[column].length, row_nr); + len= ma_get_length(stmt, column, row_nr); - if ((long)len == STMT_INDICATOR_NTS) + if (len == (ulong)-1) len= (ulong)strlen((char *)buf); + to = mysql_net_store_length(*p, len); if (len) @@ -717,10 +737,8 @@ unsigned char* mysql_stmt_execute_generate_request(MYSQL_STMT *stmt, size_t *req { indicator= STMT_INDICATOR_NULL; } - else if (stmt->row_size) - indicator= *(char *)(stmt->params[i].u.indicator + j * stmt->row_size); else - indicator= stmt->params[i].u.indicator[j]; + indicator= ma_get_indicator(stmt, i, j); /* check if we need to send data */ if (indicator == STMT_INDICATOR_NULL || indicator == STMT_INDICATOR_DEFAULT) @@ -782,9 +800,11 @@ unsigned char* mysql_stmt_execute_generate_request(MYSQL_STMT *stmt, size_t *req goto mem_error; p= start + offset; } - if ((stmt->params[i].is_null && *stmt->params[i].is_null) || + + if (indicator != STMT_INDICATOR_DEFAULT && + ((stmt->params[i].is_null && *stmt->params[i].is_null) || stmt->params[i].buffer_type == MYSQL_TYPE_NULL || - !stmt->params[i].buffer) + !stmt->params[i].buffer)) { has_data= FALSE; if (!stmt->array_size) @@ -794,7 +814,7 @@ unsigned char* mysql_stmt_execute_generate_request(MYSQL_STMT *stmt, size_t *req } if (bulk_supported && (indicator || stmt->params[i].u.indicator)) { - int1store(p, indicator); + int1store(p, indicator > 0 ? indicator : 0); p++; } if (has_data) diff --git a/unittest/libmariadb/bulk1.c b/unittest/libmariadb/bulk1.c index 05bc9490..246e2ec5 100644 --- a/unittest/libmariadb/bulk1.c +++ b/unittest/libmariadb/bulk1.c @@ -186,19 +186,19 @@ static int bulk3(MYSQL *mysql) { struct st_bulk3 { char char_value[20]; - uchar indicator; + unsigned long length; int int_value; }; - struct st_bulk3 val[]= {{"Row 1", STMT_INDICATOR_NTS, 1}, - {"Row 2", STMT_INDICATOR_NTS, 2}, - {"Row 3", STMT_INDICATOR_NTS, 3}}; + struct st_bulk3 val[3]= {{"Row 1", 5, 1}, + {"Row 02", 6, 2}, + {"Row 003", 7, 3}}; int rc; MYSQL_BIND bind[2]; MYSQL_STMT *stmt= mysql_stmt_init(mysql); size_t row_size= sizeof(struct st_bulk3); int array_size= 3; - ulong length= -1; + int i; if (!bulk_enabled) return SKIP; @@ -219,7 +219,7 @@ static int bulk3(MYSQL *mysql) bind[0].buffer_type= MYSQL_TYPE_STRING; bind[0].buffer= &val[0].char_value; - bind[0].length= &length; + bind[0].length= &val[0].length; bind[1].buffer_type= MYSQL_TYPE_LONG; bind[1].buffer= &val[0].int_value; @@ -232,6 +232,66 @@ static int bulk3(MYSQL *mysql) return OK; } +static int bulk4(MYSQL *mysql) +{ + struct st_bulk4 { + char char_value[20]; + char indicator1; + int int_value; + char indicator2; + }; + + struct st_bulk4 val[]= {{"Row 1", STMT_INDICATOR_NTS, 0, STMT_INDICATOR_DEFAULT}, + {"Row 2", STMT_INDICATOR_NTS, 0, STMT_INDICATOR_DEFAULT}, + {"Row 3", STMT_INDICATOR_NTS, 0, STMT_INDICATOR_DEFAULT}}; + int rc; + MYSQL_BIND bind[2]; + MYSQL_RES *res; + MYSQL_STMT *stmt= mysql_stmt_init(mysql); + size_t row_size= sizeof(struct st_bulk4); + int array_size= 3; + unsigned long lengths[3]= {-1, -1, -1}; + + if (!bulk_enabled) + return SKIP; + rc= mysql_query(mysql, "DROP TABLE IF EXISTS bulk4"); + check_mysql_rc(rc,mysql); + rc= mysql_query(mysql, "CREATE TABLE bulk4 (name varchar(20), row int not null default 3)"); + check_mysql_rc(rc,mysql); + + rc= mysql_stmt_prepare(stmt, "INSERT INTO bulk4 VALUES (?,?)", -1); + check_stmt_rc(rc, stmt); + + memset(bind, 0, sizeof(MYSQL_BIND)*2); + + rc= mysql_stmt_attr_set(stmt, STMT_ATTR_ARRAY_SIZE, &array_size); + check_stmt_rc(rc, stmt); + rc= mysql_stmt_attr_set(stmt, STMT_ATTR_ROW_SIZE, &row_size); + check_stmt_rc(rc, stmt); + + bind[0].buffer_type= MYSQL_TYPE_STRING; + bind[0].u.indicator= &val[0].indicator1; + bind[0].buffer= &val[0].char_value; + bind[0].length= lengths; + bind[1].buffer_type= MYSQL_TYPE_LONG; + bind[1].u.indicator= &val[0].indicator2; + + rc= mysql_stmt_bind_param(stmt, bind); + check_stmt_rc(rc, stmt); + rc= mysql_stmt_execute(stmt); + check_stmt_rc(rc, stmt); + + mysql_stmt_close(stmt); + + rc= mysql_query(mysql, "SELECT * FROM bulk4 WHERE row=3"); + check_mysql_rc(rc, mysql); + res= mysql_store_result(mysql); + rc= mysql_num_rows(res); + mysql_free_result(res); + FAIL_IF(rc != 3, "expected 3 rows"); + return OK; +} + static int bulk_null(MYSQL *mysql) { MYSQL_STMT *stmt= mysql_stmt_init(mysql); @@ -282,6 +342,7 @@ struct my_tests_st my_tests[] = { {"bulk1", bulk1, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"bulk2", bulk2, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"bulk3", bulk3, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, + {"bulk4", bulk4, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {"bulk_null", bulk_null, TEST_CONNECTION_DEFAULT, 0, NULL, NULL}, {NULL, NULL, 0, 0, NULL, NULL} };