diff --git a/mysql-test/main/mysqld--help.result b/mysql-test/main/mysqld--help.result index 30ff9a35953..6bd8f53d8c5 100644 --- a/mysql-test/main/mysqld--help.result +++ b/mysql-test/main/mysqld--help.result @@ -412,6 +412,11 @@ The following specify which files/extra groups are read (specified before remain height-balanced, DOUBLE_PREC_HB - double precision height-balanced, JSON_HB - height-balanced, stored as JSON + --hnsw-ef-constructor + hnsw_ef_constructor + --hnsw-ef-search hnsw_ef_search + --hnsw-max-connection-per-layer + hnsw_max_connection_per_layer --host-cache-size=# How many host names should be cached to avoid resolving (Automatically configured unless set explicitly) --idle-readonly-transaction-timeout=# @@ -1732,6 +1737,9 @@ gtid-strict-mode FALSE help TRUE histogram-size 254 histogram-type JSON_HB +hnsw-ef-constructor 10 +hnsw-ef-search 10 +hnsw-max-connection-per-layer 50 host-cache-size 279 idle-readonly-transaction-timeout 0 idle-transaction-timeout 0 diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result index 76fe78b9b55..108877a569f 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result +++ b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result @@ -1432,6 +1432,36 @@ NUMERIC_BLOCK_SIZE NULL ENUM_VALUE_LIST SINGLE_PREC_HB,DOUBLE_PREC_HB,JSON_HB READ_ONLY NO COMMAND_LINE_ARGUMENT REQUIRED +VARIABLE_NAME HNSW_EF_CONSTRUCTOR +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_ef_constructor +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE +VARIABLE_NAME HNSW_EF_SEARCH +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_ef_search +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE +VARIABLE_NAME HNSW_MAX_CONNECTION_PER_LAYER +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_max_connection_per_layer +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE VARIABLE_NAME HOSTNAME VARIABLE_SCOPE GLOBAL VARIABLE_TYPE VARCHAR diff --git a/sql/item.h b/sql/item.h index 28d7727a540..c7331391508 100644 --- a/sql/item.h +++ b/sql/item.h @@ -6634,7 +6634,6 @@ public: #include "item_subselect.h" #include "item_xmlfunc.h" #include "item_jsonfunc.h" -#include "item_vectorfunc.h" #include "item_create.h" #include "item_vers.h" #endif diff --git a/sql/item_create.cc b/sql/item_create.cc index 7cd85a06007..ba7b5c05bdd 100644 --- a/sql/item_create.cc +++ b/sql/item_create.cc @@ -36,6 +36,7 @@ #include "sp.h" #include "sql_time.h" #include "sql_type_geom.h" +#include "item_vectorfunc.h" #include diff --git a/sql/item_vectorfunc.cc b/sql/item_vectorfunc.cc index 6ff51be9fe4..6e9e23779a0 100644 --- a/sql/item_vectorfunc.cc +++ b/sql/item_vectorfunc.cc @@ -23,6 +23,7 @@ #include #include "item.h" +#include "item_vectorfunc.h" key_map Item_func_vec_distance::part_of_sortkey() const { @@ -48,8 +49,18 @@ double Item_func_vec_distance::val_real() return 0; float *v1= (float*)r1->ptr(); float *v2= (float*)r2->ptr(); + return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float)); +} + +double euclidean_vec_distance(float *v1, float *v2, size_t v_len) +{ + float *p1= v1; + float *p2= v2; double d= 0; - for (uint i=0; i < r1->length() / sizeof(float); i++) - d+= (v1[i] - v2[i])*(v1[i] - v2[i]); + for (size_t i= 0; i < v_len; p1++, p2++, i++) + { + float dist= *p1 - *p2; + d+= dist * dist; + } return sqrt(d); } diff --git a/sql/item_vectorfunc.h b/sql/item_vectorfunc.h index 0adff0c0580..731f0315245 100644 --- a/sql/item_vectorfunc.h +++ b/sql/item_vectorfunc.h @@ -17,6 +17,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */ /* This file defines all vector functions */ +#include +#include "item.h" #include "lex_string.h" #include "item_func.h" @@ -34,6 +36,7 @@ class Item_func_vec_distance: public Item_real_func { return check_argument_types_or_binary(NULL, 0, arg_count); } + public: Item_func_vec_distance(THD *thd, Item *a, Item *b) :Item_real_func(thd, a, b) {} @@ -51,6 +54,9 @@ public: key_map part_of_sortkey() const override; Item *do_get_copy(THD *thd) const override { return get_item_copy(thd, this); } + virtual ~Item_func_vec_distance() {}; }; + +double euclidean_vec_distance(float *v1, float *v2, size_t v_len); #endif diff --git a/sql/sql_base.cc b/sql/sql_base.cc index 49e90fc2c2a..56b611fc816 100644 --- a/sql/sql_base.cc +++ b/sql/sql_base.cc @@ -9883,7 +9883,7 @@ int TABLE::hlindex_open(uint nr) mysql_mutex_unlock(&s->LOCK_share); TABLE *table= (TABLE*)alloc_root(&mem_root, sizeof(*table)); if (!table || - open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, 0, + open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, EXTRA_RECORD, in_use->open_options, table, 0)) return 1; hlindex= table; @@ -9938,7 +9938,7 @@ int TABLE::hlindex_read_first(uint nr, Item *item, ulonglong limit) DBUG_ASSERT(hlindex->in_use == in_use); - return mhnsw_read_first(this, item, limit); + return mhnsw_read_first(this, key_info + s->keys, item, limit); } int TABLE::hlindex_read_next() diff --git a/sql/sql_class.h b/sql/sql_class.h index 2b6ea6b23b7..609a18cac54 100644 --- a/sql/sql_class.h +++ b/sql/sql_class.h @@ -922,6 +922,11 @@ typedef struct system_variables my_bool binlog_alter_two_phase; Charset_collation_map_st character_set_collations; + + /* Temporary for HNSW tests */ + uint hnsw_max_connection_per_layer; + uint hnsw_ef_constructor; + uint hnsw_ef_search; } SV; /** diff --git a/sql/sys_vars.cc b/sql/sys_vars.cc index dc4ecef9b50..88d380f4b4d 100644 --- a/sql/sys_vars.cc +++ b/sql/sys_vars.cc @@ -7447,3 +7447,23 @@ static Sys_var_ulonglong Sys_binlog_large_commit_threshold( // Allow a smaller minimum value for debug builds to help with testing VALID_RANGE(IF_DBUG(100, 10240) * 1024, ULLONG_MAX), DEFAULT(128 * 1024 * 1024), BLOCK_SIZE(1)); + +/* Temporary for HNSW tests */ +static Sys_var_uint Sys_hnsw_ef_search( + "hnsw_ef_search", + "hnsw_ef_search", + SESSION_VAR(hnsw_ef_search), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(10), + BLOCK_SIZE(1)); +static Sys_var_uint Sys_hnsw_ef_constructor( + "hnsw_ef_constructor", + "hnsw_ef_constructor", + SESSION_VAR(hnsw_ef_constructor), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(10), + BLOCK_SIZE(1)); +static Sys_var_uint Sys_hnsw_max_connection_per_layer( + "hnsw_max_connection_per_layer", + "hnsw_max_connection_per_layer", + SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(50), + BLOCK_SIZE(1)); diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 9fe5378f177..623500a7fe4 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -14,196 +14,814 @@ along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */ +#include #include #include "vector_mhnsw.h" + #include "field.h" +#include "hash.h" #include "item.h" +#include "item_vectorfunc.h" +#include "key.h" +#include "my_base.h" +#include "mysql/psi/psi_base.h" #include "sql_queue.h" -#include + +#define HNSW_MAX_M 10000 const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\ CREATE TABLE i ( \ + layer int not null, \ src varbinary(255) not null, \ - dst varbinary(255) not null, \ - index (src)) \ + neighbors varbinary(10000) not null, \ + index (layer, src)) \ ")}; -static void store_ref(TABLE *t, handler *h, uint n) + +class FVectorRef { - t->hlindex->field[n]->store((char*)h->ref, h->ref_length, &my_charset_bin); +public: + // Shallow ref copy. Used for other ref lookups in HashSet + FVectorRef(uchar *ref, size_t ref_len): ref{ref}, ref_len{ref_len} {} + virtual ~FVectorRef() {} + + static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool) + { + *key_len= elem->ref_len; + return elem->ref; + } + + static void free_vector(void *elem) + { + delete (FVectorRef *)elem; + } + + size_t get_ref_len() const { return ref_len; } + const uchar* get_ref() const { return ref; } + +protected: + FVectorRef() = default; + uchar *ref; + size_t ref_len; +}; + +class FVector; + +Hash_set all_vector_set( + PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + +Hash_set all_vector_ref_set( + PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + +class FVector: public FVectorRef +{ +private: + float *vec; + size_t vec_len; +public: + FVector(): vec(nullptr), vec_len(0) {} + ~FVector() { my_free(this->ref); } + + bool init(const uchar *ref, size_t ref_len, + const float *vec, size_t vec_len) + { + this->ref= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, + ref_len + vec_len * sizeof(float), + MYF(0)); + if (!this->ref) + return true; + + this->vec= reinterpret_cast(this->ref + ref_len); + + memcpy(this->ref, ref, ref_len); + memcpy(this->vec, vec, vec_len * sizeof(float)); + + this->ref_len= ref_len; + this->vec_len= vec_len; + return false; + } + + size_t get_vec_len() const { return vec_len; } + const float* get_vec() const { return vec; } + + double distance_to(const FVector &other) const + { + DBUG_ASSERT(other.vec_len == vec_len); + return euclidean_vec_distance(vec, other.vec, vec_len); + } + + static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len) + { + FVectorRef tmp{(uchar*)ref, ref_len}; + FVectorRef *v= all_vector_ref_set.find(&tmp); + if (v) + return v; + + // TODO(cvicentiu) memory management. + uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0)); + memcpy(buf, ref, ref_len); + v= new FVectorRef{buf, ref_len}; + all_vector_ref_set.insert(v); + return v; + } + + static FVector *get_fvector_from_source(TABLE *source, + Field *vect_field, + const FVectorRef &ref) + { + + FVectorRef *v= all_vector_set.find(&ref); + if (v) + return (FVector *)v; + + FVector *new_vector= new FVector; + if (!new_vector) + return nullptr; + + source->file->ha_rnd_pos(source->record[0], + const_cast(ref.get_ref())); + + String buf, *vec; + vec= vect_field->val_str(&buf); + + // TODO(cvicentiu) error checking + new_vector->init(ref.get_ref(), ref.get_ref_len(), + reinterpret_cast(vec->ptr()), + vec->length() / sizeof(float)); + + all_vector_set.insert(new_vector); + + return new_vector; + } +}; + + + + +static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b) +{ + double a_dist= reference->distance_to(*a); + double b_dist= reference->distance_to(*b); + + if (a_dist < b_dist) + return -1; + if (a_dist > b_dist) + return 1; + return 0; } +const bool KEEP_PRUNED_CONNECTIONS=true; +const bool EXTEND_CANDIDATES=true; + +static bool get_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + List *neighbours); + +static bool select_neighbours(TABLE *source, TABLE *graph, + Field *vect_field, + size_t layer_number, + const FVector &target, + const List &candidates, + size_t max_neighbour_connections, + List *neighbours) +{ + /* + TODO: If the input neighbours list is already sorted in search_layer, then + no need to do additional queue build steps here. + */ + + Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + Queue pq; + Queue pq_discard; + Queue best; + + // TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size. + // This should not be fixed. + pq.init(10000, 0, cmp_vec, &target); + pq_discard.init(10000, 0, cmp_vec, &target); + best.init(max_neighbour_connections, true, cmp_vec, &target); + + // TODO(cvicentiu) error checking. + for (const FVectorRef &candidate : candidates) + { + pq.push(FVector::get_fvector_from_source(source, vect_field, candidate)); + visited.insert(&candidate); + } + + + if (EXTEND_CANDIDATES) + { + for (const FVectorRef &candidate : candidates) + { + List candidate_neighbours; + get_neighbours(graph, layer_number, candidate, &candidate_neighbours); + for (const FVectorRef &extra_candidate : candidate_neighbours) + { + if (visited.find(&extra_candidate)) + continue; + visited.insert(&extra_candidate); + pq.push(FVector::get_fvector_from_source(source, + vect_field, + extra_candidate)); + } + } + } + + DBUG_ASSERT(pq.elements()); + best.push(pq.pop()); + + double best_top = best.top()->distance_to(target); + while (pq.elements() && best.elements() < max_neighbour_connections) + { + const FVector *vec= pq.pop(); + double cur_dist = vec->distance_to(target); + // TODO(cvicentiu) best distance can be cached. + if (cur_dist < best_top) { + + best.push(vec); + best_top = cur_dist; + } + else + pq_discard.push(vec); + } + + if (KEEP_PRUNED_CONNECTIONS) + { + while (pq_discard.elements() && + best.elements() < max_neighbour_connections) + { + best.push(pq_discard.pop()); + } + } + + DBUG_ASSERT(best.elements() <= max_neighbour_connections); + while (best.elements()) { + neighbours->push_front(best.pop()); + } + + return false; +} + +//static bool select_neighbours(TABLE *source, TABLE *graph, +// Field *vect_field, +// size_t layer_number, +// const FVector &target, +// const List &candidates, +// size_t max_neighbour_connections, +// List *neighbours) +//{ +// /* +// TODO: If the input neighbours list is already sorted in search_layer, then +// no need to do additional queue build steps here. +// */ +// +// Queue pq; +// pq.init(candidates.elements, 0, 0, cmp_vec, &target); +// +// // TODO(cvicentiu) error checking. +// for (const FVectorRef &candidate : candidates) +// pq.push(FVector::get_fvector_from_source(source, vect_field, candidate)); +// +// for (size_t i = 0; i < max_neighbour_connections; i++) +// { +// if (!pq.elements()) +// break; +// neighbours->push_back(pq.pop()); +// } +// +// return false; +//} + + +static void dbug_print_vec_ref(const char *prefix, + uint layer, + const FVectorRef &ref) +{ +#ifndef DBUG_OFF + // TODO(cvicentiu) disable this in release build. + char *ref_str= (char *)alloca(ref.get_ref_len() * 2 + 1); + DBUG_ASSERT(ref_str); + char *ptr= ref_str; + for (size_t i = 0; i < ref.get_ref_len(); ptr += 2, i++) + { + snprintf(ptr, 3, "%02x", ref.get_ref()[i]); + } + DBUG_PRINT("VECTOR", ("%s %u %s", prefix, layer, ref_str)); +#endif +} + +static void dbug_print_vec_neigh(uint layer, + const List &neighbors) +{ +#ifndef DBUG_OFF + DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements)); + for (const FVectorRef& ref : neighbors) + { + dbug_print_vec_ref("NEIGH: ", layer, ref); + } +#endif +} + +static bool write_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + const List &new_neighbours) +{ + DBUG_ASSERT(new_neighbours.elements <= HNSW_MAX_M); + + + size_t total_size= sizeof(uint16_t) + + new_neighbours.elements * source_node.get_ref_len(); + + // Allocate memory for the struct and the flexible array member + char *neighbor_array_bytes= static_cast(alloca(total_size)); + + DBUG_ASSERT(new_neighbours.elements <= INT16_MAX); + *(uint16_t *) neighbor_array_bytes= new_neighbours.elements; + char *pos= neighbor_array_bytes + sizeof(uint16_t); + for (const auto &node: new_neighbours) + { + DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len()); + memcpy(pos, node.get_ref(), node.get_ref_len()); + pos+= node.get_ref_len(); + } + + graph->field[0]->store(layer_number); + graph->field[1]->store_binary( + reinterpret_cast(source_node.get_ref()), + source_node.get_ref_len()); + graph->field[2]->set_null(); + + + uchar *key= (uchar*)alloca(graph->key_info->key_length); + key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); + + int err= graph->file->ha_index_read_map(graph->record[1], key, + HA_WHOLE_KEY, + HA_READ_KEY_EXACT); + + // no record + if (err == HA_ERR_KEY_NOT_FOUND) + { + dbug_print_vec_ref("INSERT ", layer_number, source_node); + graph->field[2]->store_binary(neighbor_array_bytes, total_size); + graph->file->ha_write_row(graph->record[0]); + return false; + } + dbug_print_vec_ref("UPDATE ", layer_number, source_node); + dbug_print_vec_neigh(layer_number, new_neighbours); + + graph->field[2]->store_binary(neighbor_array_bytes, total_size); + graph->file->ha_update_row(graph->record[1], graph->record[0]); + return false; +} + + +static bool get_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + List *neighbours) +{ + // TODO(cvicentiu) This allocation need not happen in this function. + uchar *key= (uchar*)alloca(graph->key_info->key_length); + + graph->field[0]->store(layer_number); + graph->field[1]->store_binary( + reinterpret_cast(source_node.get_ref()), + source_node.get_ref_len()); + graph->field[2]->set_null(); + key_copy(key, graph->record[0], + graph->key_info, graph->key_info->key_length); + if ((graph->file->ha_index_read_map(graph->record[0], key, + HA_WHOLE_KEY, + HA_READ_KEY_EXACT))) + return true; + + //TODO This does two memcpys, one should use str's buffer. + String strbuf; + String *str= graph->field[2]->val_str(&strbuf); + + // All ref should have same length + uint ref_length= source_node.get_ref_len(); + + const uchar *neigh_arr_bytes= reinterpret_cast(str->ptr()); + uint16_t number_of_neighbours= + *reinterpret_cast(neigh_arr_bytes); + if (number_of_neighbours != (str->length() - sizeof(uint16_t)) / ref_length) + { + /* + neighbours number does not match the data length, + should not happen, possible corrupted HNSW index + */ + DBUG_ASSERT(0); // TODO(cvicentiu) remove this after testing. + return true; + } + + const uchar *pos = neigh_arr_bytes + sizeof(uint16_t); + for (uint16_t i= 0; i < number_of_neighbours; i++) + { + neighbours->push_back(FVector::get_fvector_ref(pos, ref_length)); + pos+= ref_length; + } + + return false; +} + + +static bool update_second_degree_neighbors(TABLE *source, + Field *vec_field, + TABLE *graph, + size_t layer_number, + uint max_neighbours, + const FVectorRef &source_node, + const List &neighbours) +{ + //dbug_print_vec_ref("Updating second degree neighbours", layer_number, source_node); + //dbug_print_vec_neigh(layer_number, neighbours); + for (const FVectorRef &neigh: neighbours) + { + List new_neighbours; + get_neighbours(graph, layer_number, neigh, &new_neighbours); + new_neighbours.push_back(&source_node); + write_neighbours(graph, layer_number, neigh, new_neighbours); + } + + for (const FVectorRef &neigh: neighbours) + { + List new_neighbours; + get_neighbours(graph, layer_number, neigh, &new_neighbours); + // TODO(cvicentiu) get_fvector_from_source results must not need to be freed. + FVector *neigh_vec = FVector::get_fvector_from_source(source, vec_field, neigh); + + if (new_neighbours.elements > max_neighbours) + { + // shrink the neighbours + List selected; + select_neighbours(source, graph, vec_field, layer_number, + *neigh_vec, new_neighbours, + max_neighbours, &selected); + write_neighbours(graph, layer_number, neigh, selected); + } + + // release memory + new_neighbours.empty(); + } + + return false; +} + + +static bool update_neighbours(TABLE *source, + TABLE *graph, + Field *vec_field, + size_t layer_number, + uint max_neighbours, + const FVectorRef &source_node, + const List &neighbours) +{ + // 1. update node's neighbours + write_neighbours(graph, layer_number, source_node, neighbours); + // 2. update node's neighbours' neighbours (shrink before update) + update_second_degree_neighbors(source, vec_field, graph, layer_number, + max_neighbours, source_node, neighbours); + return false; +} + + +static bool search_layer(TABLE *source, + TABLE *graph, + Field *vec_field, + const FVector &target, + const List &start_nodes, + uint max_candidates_return, + size_t layer, + List *result) +{ + DBUG_ASSERT(start_nodes.elements > 0); + // Result list must be empty, otherwise there's a risk of memory leak + DBUG_ASSERT(result->elements == 0); + + Queue candidates; + Queue best; + //TODO(cvicentiu) Fix this hash method. + Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + candidates.init(10000, false, cmp_vec, &target); + best.init(max_candidates_return, true, cmp_vec, &target); + + for (const FVectorRef &node : start_nodes) + { + FVector *v= FVector::get_fvector_from_source(source, vec_field, node); + candidates.push(v); + if (best.elements() < max_candidates_return) + best.push(v); + else if (target.distance_to(*v) > target.distance_to(*best.top())) { + best.replace_top(v); + } + visited.insert(v); + dbug_print_vec_ref("INSERTING node in visited: ", layer, node); + } + + double furthest_best = target.distance_to(*best.top()); + while (candidates.elements()) + { + const FVector &cur_vec= *candidates.pop(); + double cur_distance = target.distance_to(cur_vec); + if (cur_distance > furthest_best && best.elements() == max_candidates_return) + { + break; // All possible candidates are worse than what we have. + // Can't get better. + } + + List neighbours; + get_neighbours(graph, layer, cur_vec, &neighbours); + + for (const FVectorRef &neigh: neighbours) + { + if (visited.find(&neigh)) + continue; + + FVector *clone = FVector::get_fvector_from_source(source, vec_field, neigh); + // TODO(cvicentiu) mem ownershipw... + visited.insert(clone); + if (best.elements() < max_candidates_return) + { + candidates.push(clone); + best.push(clone); + furthest_best = target.distance_to(*best.top()); + } + else if (target.distance_to(*clone) < furthest_best) + { + best.replace_top(clone); + candidates.push(clone); + furthest_best = target.distance_to(*best.top()); + } + } + neighbours.empty(); + } + DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements())); + + while (best.elements()) + { + // TODO(cvicentiu) FVector memory leak. + // TODO(cvicentiu) this is n*log(n), we need a queue iterator. + result->push_front(best.pop()); + } + + return false; +} + + +std::mt19937 gen(42); // TODO(cvicentiu) seeded with 42 for now, this should + // use a rnd service + int mhnsw_insert(TABLE *table, KEY *keyinfo) { TABLE *graph= table->hlindex; MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); - Field *field= keyinfo->key_part->field; - String buf, *res= field->val_str(&buf); + Field *vec_field= keyinfo->key_part->field; + String buf, *res= vec_field->val_str(&buf); handler *h= table->file; int err= 0; - dbug_tmp_restore_column_map(&table->read_set, old_map); /* metadata are checked on open */ DBUG_ASSERT(graph); DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); DBUG_ASSERT(keyinfo->usable_key_parts == 1); - DBUG_ASSERT(field->binary()); - DBUG_ASSERT(field->cmp_type() == STRING_RESULT); + DBUG_ASSERT(vec_field->binary()); + DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL - DBUG_ASSERT(h->ref_length <= graph->field[0]->field_length); DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length); + DBUG_ASSERT(h->ref_length <= graph->field[2]->field_length); if (res->length() == 0 || res->length() % 4) return 1; - // let's do every node to every node - h->position(table->record[0]); - graph->field[0]->store(1); - store_ref(table, h, 0); + const double NORMALIZATION_FACTOR = 1 / std::log(1.0 * + table->in_use->variables.hnsw_max_connection_per_layer); - if (h->lookup_handler->ha_rnd_init(1)) - return 1; - while (! ((err= h->lookup_handler->ha_rnd_next(h->lookup_buffer)))) - { - h->lookup_handler->position(h->lookup_buffer); - if (graph->field[0]->cmp(h->lookup_handler->ref) == 0) - continue; - store_ref(table, h->lookup_handler, 1); - if ((err= graph->file->ha_write_row(graph->record[0]))) - break; - } - h->lookup_handler->ha_rnd_end(); + if ((err= h->ha_rnd_init(1))) + return err; - return err == HA_ERR_END_OF_FILE ? 0 : err; -} - -struct Node -{ - float distance; - uchar ref[1000]; -}; - -static int cmp_float(void *, const Node *a, const Node *b) -{ - return a->distance < b->distance ? -1 : a->distance == b->distance ? 0 : 1; -} - -int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit) -{ - TABLE *graph= table->hlindex; - Queue todo, result; - Node *cur; - String *str, strbuf; - const size_t ref_length= table->file->ref_length; - const size_t element_size= ref_length + sizeof(float); - uchar *key= (uchar*)alloca(ref_length + 32); - Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, limit, - sizeof(float), ref_length, 0, 0, HASH_UNIQUE); - uint keylen; - int err= 0; - - DBUG_ASSERT(graph); - - if (todo.init(1000, 0, cmp_float)) // XXX + autoextent - return HA_ERR_OUT_OF_MEM; - - if (result.init(limit, 1, cmp_float)) - return HA_ERR_OUT_OF_MEM; if ((err= graph->file->ha_index_init(0, 1))) return err; - SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); - - // 1. read a start row + longlong max_layer; if ((err= graph->file->ha_index_last(graph->record[0]))) - return err; - - if (!(str= graph->field[0]->val_str(&strbuf))) - return HA_ERR_CRASHED; - - DBUG_ASSERT(str->length() == ref_length); - - cur= (Node*)table->in_use->alloc(element_size); - memcpy(cur->ref, str->ptr(), ref_length); - - if ((err= table->file->ha_rnd_init(0))) - return err; - - if ((err= table->file->ha_rnd_pos(table->record[0], cur->ref))) - return HA_ERR_CRASHED; - - // 2. add it to the todo - cur->distance= dist->val_real(); - if (dist->is_null()) - return HA_ERR_END_OF_FILE; - todo.push(cur); - visited.insert(cur); - - while (todo.elements()) { - // 3. pick the top node from the todo - cur= todo.pop(); - - // 4. add it to the result - if (result.is_full()) + if (err != HA_ERR_END_OF_FILE) { - // 5. if not added, greedy search done - if (cur->distance > result.top()->distance) - break; - result.replace_top(cur); + graph->file->ha_index_end(); + return err; } - else - result.push(cur); + // First insert! + h->position(table->record[0]); + write_neighbours(graph, 0, {h->ref, h->ref_length}, {}); - float threshold= result.is_full() ? result.top()->distance : FLT_MAX; + h->ha_rnd_end(); + graph->file->ha_index_end(); + return 0; // TODO (error during store_link) + } + else + max_layer= graph->field[0]->val_int(); - // 6. add all its [yet unvisited] neighbours to the todo heap - keylen= graph->field[0]->get_key_image(key, ref_length, Field::itRAW); - if ((err= graph->file->ha_index_read_map(graph->record[0], key, 3, - HA_READ_KEY_EXACT))) - return HA_ERR_CRASHED; + FVector target; + h->position(table->record[0]); + // TODO (cvicentiu) Error checking. + target.init(h->ref, h->ref_length, + reinterpret_cast(res->ptr()), + res->length() / sizeof(float)); - do { - if (!(str= graph->field[1]->val_str(&strbuf))) - return HA_ERR_CRASHED; - if (visited.find(str->ptr(), ref_length)) - continue; + std::uniform_real_distribution<> dis(0.0, 1.0); + double new_num= dis(gen); + double log= -std::log(new_num) * NORMALIZATION_FACTOR; + longlong new_node_layer= std::floor(log); - if ((err= table->file->ha_rnd_pos(table->record[0], (uchar*)str->ptr()))) - return HA_ERR_CRASHED; + List start_nodes; - float distance= dist->val_real(); - if (distance > threshold) - continue; + String ref_str, *ref_ptr; + ref_ptr= graph->field[1]->val_str(&ref_str); - cur= (Node*)table->in_use->alloc(element_size); - cur->distance= distance; - memcpy(cur->ref, str->ptr(), ref_length); - todo.push(cur); - visited.insert(cur); - } while (!graph->file->ha_index_next_same(graph->record[0], key, keylen)); - // 7. goto 3 + FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; + //FVector *start_node= start_node_ref.get_fvector_from_source(table, vec_field); + + // TODO(cvicentiu) error checking. Also make sure we use a random start node + // in last layer. + start_nodes.push_back(&start_node_ref); + // TODO start_nodes needs to have one element in it. + for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) + { + List candidates; + search_layer(table, graph, vec_field, target, start_nodes, + table->in_use->variables.hnsw_ef_constructor, cur_layer, + &candidates); + start_nodes.empty(); + start_nodes.push_back(candidates.head()); + //candidates.delete_elements(); + //TODO(cvicentiu) memory leak } + for (longlong cur_layer= std::min(max_layer, new_node_layer); + cur_layer >= 0; cur_layer--) + { + List candidates; + List neighbours; + search_layer(table, graph, vec_field, target, start_nodes, + table->in_use->variables.hnsw_ef_constructor, + cur_layer, &candidates); + // release vectors + start_nodes.empty(); + + uint max_neighbours= (cur_layer == 0) ? + table->in_use->variables.hnsw_max_connection_per_layer * 2 + : table->in_use->variables.hnsw_max_connection_per_layer; + + select_neighbours(table, graph, vec_field, cur_layer, + target, candidates, + max_neighbours, &neighbours); + update_neighbours(table, graph, vec_field, cur_layer, max_neighbours, + target, neighbours); + start_nodes= candidates; + } + start_nodes.empty(); + + for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer; + cur_layer++) + { + write_neighbours(graph, cur_layer, target, {}); + } + + + h->ha_rnd_end(); + graph->file->ha_index_end(); + dbug_tmp_restore_column_map(&table->read_set, old_map); + + return err == HA_ERR_END_OF_FILE ? 0 : err; +} + + +int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) +{ + TABLE *graph= table->hlindex; + MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); + // TODO(cvicentiu) onlye one hlindex now. + Field *vec_field= keyinfo->key_part->field; + Item_func_vec_distance *fun= (Item_func_vec_distance *)dist; + String buf, *res= fun->arguments()[1]->val_str(&buf); + handler *h= table->file; + + //TODO(scope_exit) + int err; + if ((err= h->ha_rnd_init(0))) + return err; + + if ((err= graph->file->ha_index_init(0, 1))) + return err; + + h->position(table->record[0]); + FVector target; + target.init(h->ref, + h->ref_length, + reinterpret_cast(res->ptr()), + res->length() / sizeof(float)); + + List candidates; + List start_nodes; + + longlong max_layer; + if ((err= graph->file->ha_index_last(graph->record[0]))) + { + if (err != HA_ERR_END_OF_FILE) + { + graph->file->ha_index_end(); + return err; + } + h->ha_rnd_end(); + graph->file->ha_index_end(); + return 0; // TODO (error during store_link) + } + else + max_layer= graph->field[0]->val_int(); + + String ref_str, *ref_ptr; + ref_ptr= graph->field[1]->val_str(&ref_str); + FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; + // TODO(cvicentiu) error checking. Also make sure we use a random start node + // in last layer. + start_nodes.push_back(&start_node_ref); + + ulonglong ef_search= MY_MAX( + table->in_use->variables.hnsw_ef_search, limit); + + for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) + { + search_layer(table, graph, vec_field, target, start_nodes, ef_search, + cur_layer, &candidates); + start_nodes.empty(); + //start_nodes.delete_elements(); + start_nodes.push_back(candidates.head()); + //candidates.delete_elements(); + candidates.empty(); + //TODO(cvicentiu) memleak. + } + + search_layer(table, graph, vec_field, target, start_nodes, + ef_search, 0, &candidates); + // 8. return results - Node **context= (Node**)table->in_use->alloc(sizeof(Node**)*result.elements()+1); + FVectorRef **context= (FVectorRef**)table->in_use->alloc( + sizeof(FVectorRef*) * (limit + 1)); graph->context= context; - Node **ptr= context+result.elements(); - *ptr= 0; - while (result.elements()) - *--ptr= result.pop(); + FVectorRef **ptr= context; + while (limit--) + *ptr++= candidates.pop(); + *ptr= nullptr; - return mhnsw_read_next(table); + err= mhnsw_read_next(table); + graph->file->ha_index_end(); + + // TODO release vectors after query + + dbug_tmp_restore_column_map(&table->read_set, old_map); + return err; } int mhnsw_read_next(TABLE *table) { - Node ***context= (Node***)&table->hlindex->context; - if (**context) - return table->file->ha_rnd_pos(table->record[0], (*(*context)++)->ref); + FVectorRef ***context= (FVectorRef ***)&table->hlindex->context; + FVectorRef *cur_vec= **context; + if (cur_vec) + { + int err= table->file->ha_rnd_pos(table->record[0], + (uchar *)(cur_vec)->get_ref()); + // release vectors + // delete cur_vec; + + (*context)++; + return err; + } return HA_ERR_END_OF_FILE; } diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index 74a0bd8ce42..141512244b7 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -15,10 +15,14 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */ +#include +#include "item.h" +#include "m_string.h" +#include "structs.h" #include "table.h" extern const LEX_CSTRING mhnsw_hlindex_table; int mhnsw_insert(TABLE *table, KEY *keyinfo); -int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit); +int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit); int mhnsw_read_next(TABLE *table);