From 88839e71a35e7147f963ecb7978eb1a6d70dee2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vicen=C8=9Biu=20Ciorbaru?= Date: Sat, 17 Feb 2024 17:03:30 +0200 Subject: [PATCH] Initial HNSW implementation This commit includes the work done in collaboration with Hugo Wen from Amazon: MDEV-33408 Alter HNSW graph storage and fix memory leak This commit changes the way HNSW graph information is stored in the second table. Instead of storing connections as separate records, it now stores neighbors for each node, leading to significant performance improvements and storage savings. Compared with the previous approach, the insert speed is 5 times faster, search speed improves by 23%, and storage usage is reduced by 73%, based on ann-benchmark tests with random-xs-20-euclidean and random-s-100-euclidean datasets. Additionally, in the previous code, vector objects were not released after use, resulting in excessive memory consumption (over 20GB for building the index with 90,000 records), preventing tests with large datasets. Now ensure that vectors are released appropriately during the insert and search functions. Note there are still some vectors that need to be cleaned up after search query completion. Needs to be addressed in a future commit. All new code of the whole pull request, including one or several files that are either new files or modified ones, are contributed under the BSD-new license. I am contributing on behalf of my employer Amazon Web Services, Inc. 
As well as the commit: Introduce session variables to manage HNSW index parameters Three variables: hnsw_max_connection_per_layer hnsw_ef_constructor hnsw_ef_search ann-benchmark tool is also updated to support these variables in commit https://github.com/HugoWenTD/ann-benchmarks/commit/e09784e for branch https://github.com/HugoWenTD/ann-benchmarks/tree/mariadb-configurable All new code of the whole pull request, including one or several files that are either new files or modified ones, are contributed under the BSD-new license. I am contributing on behalf of my employer Amazon Web Services, Inc. Co-authored-by: Hugo Wen --- mysql-test/main/mysqld--help.result | 8 + .../r/sysvars_server_notembedded.result | 30 + sql/item.h | 1 - sql/item_create.cc | 1 + sql/item_vectorfunc.cc | 15 +- sql/item_vectorfunc.h | 6 + sql/sql_base.cc | 4 +- sql/sql_class.h | 5 + sql/sys_vars.cc | 20 + sql/vector_mhnsw.cc | 882 +++++++++++++++--- sql/vector_mhnsw.h | 6 +- 11 files changed, 840 insertions(+), 138 deletions(-) diff --git a/mysql-test/main/mysqld--help.result b/mysql-test/main/mysqld--help.result index 30ff9a35953..6bd8f53d8c5 100644 --- a/mysql-test/main/mysqld--help.result +++ b/mysql-test/main/mysqld--help.result @@ -412,6 +412,11 @@ The following specify which files/extra groups are read (specified before remain height-balanced, DOUBLE_PREC_HB - double precision height-balanced, JSON_HB - height-balanced, stored as JSON + --hnsw-ef-constructor + hnsw_ef_constructor + --hnsw-ef-search hnsw_ef_search + --hnsw-max-connection-per-layer + hnsw_max_connection_per_layer --host-cache-size=# How many host names should be cached to avoid resolving (Automatically configured unless set explicitly) --idle-readonly-transaction-timeout=# @@ -1732,6 +1737,9 @@ gtid-strict-mode FALSE help TRUE histogram-size 254 histogram-type JSON_HB +hnsw-ef-constructor 10 +hnsw-ef-search 10 +hnsw-max-connection-per-layer 50 host-cache-size 279 idle-readonly-transaction-timeout 0 
idle-transaction-timeout 0 diff --git a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result index 76fe78b9b55..108877a569f 100644 --- a/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result +++ b/mysql-test/suite/sys_vars/r/sysvars_server_notembedded.result @@ -1432,6 +1432,36 @@ NUMERIC_BLOCK_SIZE NULL ENUM_VALUE_LIST SINGLE_PREC_HB,DOUBLE_PREC_HB,JSON_HB READ_ONLY NO COMMAND_LINE_ARGUMENT REQUIRED +VARIABLE_NAME HNSW_EF_CONSTRUCTOR +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_ef_constructor +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE +VARIABLE_NAME HNSW_EF_SEARCH +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_ef_search +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE +VARIABLE_NAME HNSW_MAX_CONNECTION_PER_LAYER +VARIABLE_SCOPE SESSION +VARIABLE_TYPE INT UNSIGNED +VARIABLE_COMMENT hnsw_max_connection_per_layer +NUMERIC_MIN_VALUE 0 +NUMERIC_MAX_VALUE 4294967295 +NUMERIC_BLOCK_SIZE 1 +ENUM_VALUE_LIST NULL +READ_ONLY NO +COMMAND_LINE_ARGUMENT NONE VARIABLE_NAME HOSTNAME VARIABLE_SCOPE GLOBAL VARIABLE_TYPE VARCHAR diff --git a/sql/item.h b/sql/item.h index 28d7727a540..c7331391508 100644 --- a/sql/item.h +++ b/sql/item.h @@ -6634,7 +6634,6 @@ public: #include "item_subselect.h" #include "item_xmlfunc.h" #include "item_jsonfunc.h" -#include "item_vectorfunc.h" #include "item_create.h" #include "item_vers.h" #endif diff --git a/sql/item_create.cc b/sql/item_create.cc index 7cd85a06007..ba7b5c05bdd 100644 --- a/sql/item_create.cc +++ b/sql/item_create.cc @@ -36,6 +36,7 @@ #include "sp.h" #include "sql_time.h" #include "sql_type_geom.h" +#include "item_vectorfunc.h" #include diff --git a/sql/item_vectorfunc.cc b/sql/item_vectorfunc.cc index 
6ff51be9fe4..6e9e23779a0 100644 --- a/sql/item_vectorfunc.cc +++ b/sql/item_vectorfunc.cc @@ -23,6 +23,7 @@ #include #include "item.h" +#include "item_vectorfunc.h" key_map Item_func_vec_distance::part_of_sortkey() const { @@ -48,8 +49,18 @@ double Item_func_vec_distance::val_real() return 0; float *v1= (float*)r1->ptr(); float *v2= (float*)r2->ptr(); + return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float)); +} + +double euclidean_vec_distance(float *v1, float *v2, size_t v_len) +{ + float *p1= v1; + float *p2= v2; double d= 0; - for (uint i=0; i < r1->length() / sizeof(float); i++) - d+= (v1[i] - v2[i])*(v1[i] - v2[i]); + for (size_t i= 0; i < v_len; p1++, p2++, i++) + { + float dist= *p1 - *p2; + d+= dist * dist; + } return sqrt(d); } diff --git a/sql/item_vectorfunc.h b/sql/item_vectorfunc.h index 0adff0c0580..731f0315245 100644 --- a/sql/item_vectorfunc.h +++ b/sql/item_vectorfunc.h @@ -17,6 +17,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */ /* This file defines all vector functions */ +#include +#include "item.h" #include "lex_string.h" #include "item_func.h" @@ -34,6 +36,7 @@ class Item_func_vec_distance: public Item_real_func { return check_argument_types_or_binary(NULL, 0, arg_count); } + public: Item_func_vec_distance(THD *thd, Item *a, Item *b) :Item_real_func(thd, a, b) {} @@ -51,6 +54,9 @@ public: key_map part_of_sortkey() const override; Item *do_get_copy(THD *thd) const override { return get_item_copy(thd, this); } + virtual ~Item_func_vec_distance() {}; }; + +double euclidean_vec_distance(float *v1, float *v2, size_t v_len); #endif diff --git a/sql/sql_base.cc b/sql/sql_base.cc index 49e90fc2c2a..56b611fc816 100644 --- a/sql/sql_base.cc +++ b/sql/sql_base.cc @@ -9883,7 +9883,7 @@ int TABLE::hlindex_open(uint nr) mysql_mutex_unlock(&s->LOCK_share); TABLE *table= (TABLE*)alloc_root(&mem_root, sizeof(*table)); if (!table || - open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, 0, + 
open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, EXTRA_RECORD, in_use->open_options, table, 0)) return 1; hlindex= table; @@ -9938,7 +9938,7 @@ int TABLE::hlindex_read_first(uint nr, Item *item, ulonglong limit) DBUG_ASSERT(hlindex->in_use == in_use); - return mhnsw_read_first(this, item, limit); + return mhnsw_read_first(this, key_info + s->keys, item, limit); } int TABLE::hlindex_read_next() diff --git a/sql/sql_class.h b/sql/sql_class.h index 2b6ea6b23b7..609a18cac54 100644 --- a/sql/sql_class.h +++ b/sql/sql_class.h @@ -922,6 +922,11 @@ typedef struct system_variables my_bool binlog_alter_two_phase; Charset_collation_map_st character_set_collations; + + /* Temporary for HNSW tests */ + uint hnsw_max_connection_per_layer; + uint hnsw_ef_constructor; + uint hnsw_ef_search; } SV; /** diff --git a/sql/sys_vars.cc b/sql/sys_vars.cc index dc4ecef9b50..88d380f4b4d 100644 --- a/sql/sys_vars.cc +++ b/sql/sys_vars.cc @@ -7447,3 +7447,23 @@ static Sys_var_ulonglong Sys_binlog_large_commit_threshold( // Allow a smaller minimum value for debug builds to help with testing VALID_RANGE(IF_DBUG(100, 10240) * 1024, ULLONG_MAX), DEFAULT(128 * 1024 * 1024), BLOCK_SIZE(1)); + +/* Temporary for HNSW tests */ +static Sys_var_uint Sys_hnsw_ef_search( + "hnsw_ef_search", + "hnsw_ef_search", + SESSION_VAR(hnsw_ef_search), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(10), + BLOCK_SIZE(1)); +static Sys_var_uint Sys_hnsw_ef_constructor( + "hnsw_ef_constructor", + "hnsw_ef_constructor", + SESSION_VAR(hnsw_ef_constructor), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(10), + BLOCK_SIZE(1)); +static Sys_var_uint Sys_hnsw_max_connection_per_layer( + "hnsw_max_connection_per_layer", + "hnsw_max_connection_per_layer", + SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(NO_ARG), + VALID_RANGE(0, UINT_MAX), DEFAULT(50), + BLOCK_SIZE(1)); diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index 9fe5378f177..623500a7fe4 100644 --- a/sql/vector_mhnsw.cc +++ 
b/sql/vector_mhnsw.cc @@ -14,196 +14,814 @@ along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */ +#include #include #include "vector_mhnsw.h" + #include "field.h" +#include "hash.h" #include "item.h" +#include "item_vectorfunc.h" +#include "key.h" +#include "my_base.h" +#include "mysql/psi/psi_base.h" #include "sql_queue.h" -#include + +#define HNSW_MAX_M 10000 const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\ CREATE TABLE i ( \ + layer int not null, \ src varbinary(255) not null, \ - dst varbinary(255) not null, \ - index (src)) \ + neighbors varbinary(10000) not null, \ + index (layer, src)) \ ")}; -static void store_ref(TABLE *t, handler *h, uint n) + +class FVectorRef { - t->hlindex->field[n]->store((char*)h->ref, h->ref_length, &my_charset_bin); +public: + // Shallow ref copy. Used for other ref lookups in HashSet + FVectorRef(uchar *ref, size_t ref_len): ref{ref}, ref_len{ref_len} {} + virtual ~FVectorRef() {} + + static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool) + { + *key_len= elem->ref_len; + return elem->ref; + } + + static void free_vector(void *elem) + { + delete (FVectorRef *)elem; + } + + size_t get_ref_len() const { return ref_len; } + const uchar* get_ref() const { return ref; } + +protected: + FVectorRef() = default; + uchar *ref; + size_t ref_len; +}; + +class FVector; + +Hash_set all_vector_set( + PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + +Hash_set all_vector_ref_set( + PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + +class FVector: public FVectorRef +{ +private: + float *vec; + size_t vec_len; +public: + FVector(): vec(nullptr), vec_len(0) {} + ~FVector() { my_free(this->ref); } + + bool init(const uchar *ref, size_t ref_len, + const float *vec, size_t vec_len) + { + this->ref= (uchar 
*)my_malloc(PSI_NOT_INSTRUMENTED, + ref_len + vec_len * sizeof(float), + MYF(0)); + if (!this->ref) + return true; + + this->vec= reinterpret_cast(this->ref + ref_len); + + memcpy(this->ref, ref, ref_len); + memcpy(this->vec, vec, vec_len * sizeof(float)); + + this->ref_len= ref_len; + this->vec_len= vec_len; + return false; + } + + size_t get_vec_len() const { return vec_len; } + const float* get_vec() const { return vec; } + + double distance_to(const FVector &other) const + { + DBUG_ASSERT(other.vec_len == vec_len); + return euclidean_vec_distance(vec, other.vec, vec_len); + } + + static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len) + { + FVectorRef tmp{(uchar*)ref, ref_len}; + FVectorRef *v= all_vector_ref_set.find(&tmp); + if (v) + return v; + + // TODO(cvicentiu) memory management. + uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0)); + memcpy(buf, ref, ref_len); + v= new FVectorRef{buf, ref_len}; + all_vector_ref_set.insert(v); + return v; + } + + static FVector *get_fvector_from_source(TABLE *source, + Field *vect_field, + const FVectorRef &ref) + { + + FVectorRef *v= all_vector_set.find(&ref); + if (v) + return (FVector *)v; + + FVector *new_vector= new FVector; + if (!new_vector) + return nullptr; + + source->file->ha_rnd_pos(source->record[0], + const_cast(ref.get_ref())); + + String buf, *vec; + vec= vect_field->val_str(&buf); + + // TODO(cvicentiu) error checking + new_vector->init(ref.get_ref(), ref.get_ref_len(), + reinterpret_cast(vec->ptr()), + vec->length() / sizeof(float)); + + all_vector_set.insert(new_vector); + + return new_vector; + } +}; + + + + +static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b) +{ + double a_dist= reference->distance_to(*a); + double b_dist= reference->distance_to(*b); + + if (a_dist < b_dist) + return -1; + if (a_dist > b_dist) + return 1; + return 0; } +const bool KEEP_PRUNED_CONNECTIONS=true; +const bool EXTEND_CANDIDATES=true; + +static bool 
get_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + List *neighbours); + +static bool select_neighbours(TABLE *source, TABLE *graph, + Field *vect_field, + size_t layer_number, + const FVector &target, + const List &candidates, + size_t max_neighbour_connections, + List *neighbours) +{ + /* + TODO: If the input neighbours list is already sorted in search_layer, then + no need to do additional queue build steps here. + */ + + Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + Queue pq; + Queue pq_discard; + Queue best; + + // TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size. + // This should not be fixed. + pq.init(10000, 0, cmp_vec, &target); + pq_discard.init(10000, 0, cmp_vec, &target); + best.init(max_neighbour_connections, true, cmp_vec, &target); + + // TODO(cvicentiu) error checking. + for (const FVectorRef &candidate : candidates) + { + pq.push(FVector::get_fvector_from_source(source, vect_field, candidate)); + visited.insert(&candidate); + } + + + if (EXTEND_CANDIDATES) + { + for (const FVectorRef &candidate : candidates) + { + List candidate_neighbours; + get_neighbours(graph, layer_number, candidate, &candidate_neighbours); + for (const FVectorRef &extra_candidate : candidate_neighbours) + { + if (visited.find(&extra_candidate)) + continue; + visited.insert(&extra_candidate); + pq.push(FVector::get_fvector_from_source(source, + vect_field, + extra_candidate)); + } + } + } + + DBUG_ASSERT(pq.elements()); + best.push(pq.pop()); + + double best_top = best.top()->distance_to(target); + while (pq.elements() && best.elements() < max_neighbour_connections) + { + const FVector *vec= pq.pop(); + double cur_dist = vec->distance_to(target); + // TODO(cvicentiu) best distance can be cached. 
+ if (cur_dist < best_top) { + + best.push(vec); + best_top = cur_dist; + } + else + pq_discard.push(vec); + } + + if (KEEP_PRUNED_CONNECTIONS) + { + while (pq_discard.elements() && + best.elements() < max_neighbour_connections) + { + best.push(pq_discard.pop()); + } + } + + DBUG_ASSERT(best.elements() <= max_neighbour_connections); + while (best.elements()) { + neighbours->push_front(best.pop()); + } + + return false; +} + +//static bool select_neighbours(TABLE *source, TABLE *graph, +// Field *vect_field, +// size_t layer_number, +// const FVector &target, +// const List &candidates, +// size_t max_neighbour_connections, +// List *neighbours) +//{ +// /* +// TODO: If the input neighbours list is already sorted in search_layer, then +// no need to do additional queue build steps here. +// */ +// +// Queue pq; +// pq.init(candidates.elements, 0, 0, cmp_vec, &target); +// +// // TODO(cvicentiu) error checking. +// for (const FVectorRef &candidate : candidates) +// pq.push(FVector::get_fvector_from_source(source, vect_field, candidate)); +// +// for (size_t i = 0; i < max_neighbour_connections; i++) +// { +// if (!pq.elements()) +// break; +// neighbours->push_back(pq.pop()); +// } +// +// return false; +//} + + +static void dbug_print_vec_ref(const char *prefix, + uint layer, + const FVectorRef &ref) +{ +#ifndef DBUG_OFF + // TODO(cvicentiu) disable this in release build. 
+ char *ref_str= (char *)alloca(ref.get_ref_len() * 2 + 1); + DBUG_ASSERT(ref_str); + char *ptr= ref_str; + for (size_t i = 0; i < ref.get_ref_len(); ptr += 2, i++) + { + snprintf(ptr, 3, "%02x", ref.get_ref()[i]); + } + DBUG_PRINT("VECTOR", ("%s %u %s", prefix, layer, ref_str)); +#endif +} + +static void dbug_print_vec_neigh(uint layer, + const List &neighbors) +{ +#ifndef DBUG_OFF + DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements)); + for (const FVectorRef& ref : neighbors) + { + dbug_print_vec_ref("NEIGH: ", layer, ref); + } +#endif +} + +static bool write_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + const List &new_neighbours) +{ + DBUG_ASSERT(new_neighbours.elements <= HNSW_MAX_M); + + + size_t total_size= sizeof(uint16_t) + + new_neighbours.elements * source_node.get_ref_len(); + + // Allocate memory for the struct and the flexible array member + char *neighbor_array_bytes= static_cast(alloca(total_size)); + + DBUG_ASSERT(new_neighbours.elements <= INT16_MAX); + *(uint16_t *) neighbor_array_bytes= new_neighbours.elements; + char *pos= neighbor_array_bytes + sizeof(uint16_t); + for (const auto &node: new_neighbours) + { + DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len()); + memcpy(pos, node.get_ref(), node.get_ref_len()); + pos+= node.get_ref_len(); + } + + graph->field[0]->store(layer_number); + graph->field[1]->store_binary( + reinterpret_cast(source_node.get_ref()), + source_node.get_ref_len()); + graph->field[2]->set_null(); + + + uchar *key= (uchar*)alloca(graph->key_info->key_length); + key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); + + int err= graph->file->ha_index_read_map(graph->record[1], key, + HA_WHOLE_KEY, + HA_READ_KEY_EXACT); + + // no record + if (err == HA_ERR_KEY_NOT_FOUND) + { + dbug_print_vec_ref("INSERT ", layer_number, source_node); + graph->field[2]->store_binary(neighbor_array_bytes, total_size); + graph->file->ha_write_row(graph->record[0]); + 
return false; + } + dbug_print_vec_ref("UPDATE ", layer_number, source_node); + dbug_print_vec_neigh(layer_number, new_neighbours); + + graph->field[2]->store_binary(neighbor_array_bytes, total_size); + graph->file->ha_update_row(graph->record[1], graph->record[0]); + return false; +} + + +static bool get_neighbours(TABLE *graph, + size_t layer_number, + const FVectorRef &source_node, + List *neighbours) +{ + // TODO(cvicentiu) This allocation need not happen in this function. + uchar *key= (uchar*)alloca(graph->key_info->key_length); + + graph->field[0]->store(layer_number); + graph->field[1]->store_binary( + reinterpret_cast(source_node.get_ref()), + source_node.get_ref_len()); + graph->field[2]->set_null(); + key_copy(key, graph->record[0], + graph->key_info, graph->key_info->key_length); + if ((graph->file->ha_index_read_map(graph->record[0], key, + HA_WHOLE_KEY, + HA_READ_KEY_EXACT))) + return true; + + //TODO This does two memcpys, one should use str's buffer. + String strbuf; + String *str= graph->field[2]->val_str(&strbuf); + + // All ref should have same length + uint ref_length= source_node.get_ref_len(); + + const uchar *neigh_arr_bytes= reinterpret_cast(str->ptr()); + uint16_t number_of_neighbours= + *reinterpret_cast(neigh_arr_bytes); + if (number_of_neighbours != (str->length() - sizeof(uint16_t)) / ref_length) + { + /* + neighbours number does not match the data length, + should not happen, possible corrupted HNSW index + */ + DBUG_ASSERT(0); // TODO(cvicentiu) remove this after testing. 
+ return true; + } + + const uchar *pos = neigh_arr_bytes + sizeof(uint16_t); + for (uint16_t i= 0; i < number_of_neighbours; i++) + { + neighbours->push_back(FVector::get_fvector_ref(pos, ref_length)); + pos+= ref_length; + } + + return false; +} + + +static bool update_second_degree_neighbors(TABLE *source, + Field *vec_field, + TABLE *graph, + size_t layer_number, + uint max_neighbours, + const FVectorRef &source_node, + const List &neighbours) +{ + //dbug_print_vec_ref("Updating second degree neighbours", layer_number, source_node); + //dbug_print_vec_neigh(layer_number, neighbours); + for (const FVectorRef &neigh: neighbours) + { + List new_neighbours; + get_neighbours(graph, layer_number, neigh, &new_neighbours); + new_neighbours.push_back(&source_node); + write_neighbours(graph, layer_number, neigh, new_neighbours); + } + + for (const FVectorRef &neigh: neighbours) + { + List new_neighbours; + get_neighbours(graph, layer_number, neigh, &new_neighbours); + // TODO(cvicentiu) get_fvector_from_source results must not need to be freed. + FVector *neigh_vec = FVector::get_fvector_from_source(source, vec_field, neigh); + + if (new_neighbours.elements > max_neighbours) + { + // shrink the neighbours + List selected; + select_neighbours(source, graph, vec_field, layer_number, + *neigh_vec, new_neighbours, + max_neighbours, &selected); + write_neighbours(graph, layer_number, neigh, selected); + } + + // release memory + new_neighbours.empty(); + } + + return false; +} + + +static bool update_neighbours(TABLE *source, + TABLE *graph, + Field *vec_field, + size_t layer_number, + uint max_neighbours, + const FVectorRef &source_node, + const List &neighbours) +{ + // 1. update node's neighbours + write_neighbours(graph, layer_number, source_node, neighbours); + // 2. 
update node's neighbours' neighbours (shrink before update) + update_second_degree_neighbors(source, vec_field, graph, layer_number, + max_neighbours, source_node, neighbours); + return false; +} + + +static bool search_layer(TABLE *source, + TABLE *graph, + Field *vec_field, + const FVector &target, + const List &start_nodes, + uint max_candidates_return, + size_t layer, + List *result) +{ + DBUG_ASSERT(start_nodes.elements > 0); + // Result list must be empty, otherwise there's a risk of memory leak + DBUG_ASSERT(result->elements == 0); + + Queue candidates; + Queue best; + //TODO(cvicentiu) Fix this hash method. + Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, + 1000, 0, 0, + (my_hash_get_key)FVectorRef::get_key, + NULL, + HASH_UNIQUE); + + candidates.init(10000, false, cmp_vec, &target); + best.init(max_candidates_return, true, cmp_vec, &target); + + for (const FVectorRef &node : start_nodes) + { + FVector *v= FVector::get_fvector_from_source(source, vec_field, node); + candidates.push(v); + if (best.elements() < max_candidates_return) + best.push(v); + else if (target.distance_to(*v) > target.distance_to(*best.top())) { + best.replace_top(v); + } + visited.insert(v); + dbug_print_vec_ref("INSERTING node in visited: ", layer, node); + } + + double furthest_best = target.distance_to(*best.top()); + while (candidates.elements()) + { + const FVector &cur_vec= *candidates.pop(); + double cur_distance = target.distance_to(cur_vec); + if (cur_distance > furthest_best && best.elements() == max_candidates_return) + { + break; // All possible candidates are worse than what we have. + // Can't get better. + } + + List neighbours; + get_neighbours(graph, layer, cur_vec, &neighbours); + + for (const FVectorRef &neigh: neighbours) + { + if (visited.find(&neigh)) + continue; + + FVector *clone = FVector::get_fvector_from_source(source, vec_field, neigh); + // TODO(cvicentiu) mem ownershipw... 
+ visited.insert(clone); + if (best.elements() < max_candidates_return) + { + candidates.push(clone); + best.push(clone); + furthest_best = target.distance_to(*best.top()); + } + else if (target.distance_to(*clone) < furthest_best) + { + best.replace_top(clone); + candidates.push(clone); + furthest_best = target.distance_to(*best.top()); + } + } + neighbours.empty(); + } + DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements())); + + while (best.elements()) + { + // TODO(cvicentiu) FVector memory leak. + // TODO(cvicentiu) this is n*log(n), we need a queue iterator. + result->push_front(best.pop()); + } + + return false; +} + + +std::mt19937 gen(42); // TODO(cvicentiu) seeded with 42 for now, this should + // use a rnd service + int mhnsw_insert(TABLE *table, KEY *keyinfo) { TABLE *graph= table->hlindex; MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); - Field *field= keyinfo->key_part->field; - String buf, *res= field->val_str(&buf); + Field *vec_field= keyinfo->key_part->field; + String buf, *res= vec_field->val_str(&buf); handler *h= table->file; int err= 0; - dbug_tmp_restore_column_map(&table->read_set, old_map); /* metadata are checked on open */ DBUG_ASSERT(graph); DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); DBUG_ASSERT(keyinfo->usable_key_parts == 1); - DBUG_ASSERT(field->binary()); - DBUG_ASSERT(field->cmp_type() == STRING_RESULT); + DBUG_ASSERT(vec_field->binary()); + DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL - DBUG_ASSERT(h->ref_length <= graph->field[0]->field_length); DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length); + DBUG_ASSERT(h->ref_length <= graph->field[2]->field_length); if (res->length() == 0 || res->length() % 4) return 1; - // let's do every node to every node - h->position(table->record[0]); - graph->field[0]->store(1); - store_ref(table, h, 0); + const double NORMALIZATION_FACTOR = 1 / std::log(1.0 * + 
table->in_use->variables.hnsw_max_connection_per_layer); - if (h->lookup_handler->ha_rnd_init(1)) - return 1; - while (! ((err= h->lookup_handler->ha_rnd_next(h->lookup_buffer)))) - { - h->lookup_handler->position(h->lookup_buffer); - if (graph->field[0]->cmp(h->lookup_handler->ref) == 0) - continue; - store_ref(table, h->lookup_handler, 1); - if ((err= graph->file->ha_write_row(graph->record[0]))) - break; - } - h->lookup_handler->ha_rnd_end(); + if ((err= h->ha_rnd_init(1))) + return err; - return err == HA_ERR_END_OF_FILE ? 0 : err; -} - -struct Node -{ - float distance; - uchar ref[1000]; -}; - -static int cmp_float(void *, const Node *a, const Node *b) -{ - return a->distance < b->distance ? -1 : a->distance == b->distance ? 0 : 1; -} - -int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit) -{ - TABLE *graph= table->hlindex; - Queue todo, result; - Node *cur; - String *str, strbuf; - const size_t ref_length= table->file->ref_length; - const size_t element_size= ref_length + sizeof(float); - uchar *key= (uchar*)alloca(ref_length + 32); - Hash_set visited(PSI_INSTRUMENT_MEM, &my_charset_bin, limit, - sizeof(float), ref_length, 0, 0, HASH_UNIQUE); - uint keylen; - int err= 0; - - DBUG_ASSERT(graph); - - if (todo.init(1000, 0, cmp_float)) // XXX + autoextent - return HA_ERR_OUT_OF_MEM; - - if (result.init(limit, 1, cmp_float)) - return HA_ERR_OUT_OF_MEM; if ((err= graph->file->ha_index_init(0, 1))) return err; - SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); - - // 1. 
read a start row + longlong max_layer; if ((err= graph->file->ha_index_last(graph->record[0]))) - return err; - - if (!(str= graph->field[0]->val_str(&strbuf))) - return HA_ERR_CRASHED; - - DBUG_ASSERT(str->length() == ref_length); - - cur= (Node*)table->in_use->alloc(element_size); - memcpy(cur->ref, str->ptr(), ref_length); - - if ((err= table->file->ha_rnd_init(0))) - return err; - - if ((err= table->file->ha_rnd_pos(table->record[0], cur->ref))) - return HA_ERR_CRASHED; - - // 2. add it to the todo - cur->distance= dist->val_real(); - if (dist->is_null()) - return HA_ERR_END_OF_FILE; - todo.push(cur); - visited.insert(cur); - - while (todo.elements()) { - // 3. pick the top node from the todo - cur= todo.pop(); - - // 4. add it to the result - if (result.is_full()) + if (err != HA_ERR_END_OF_FILE) { - // 5. if not added, greedy search done - if (cur->distance > result.top()->distance) - break; - result.replace_top(cur); + graph->file->ha_index_end(); + return err; } - else - result.push(cur); + // First insert! + h->position(table->record[0]); + write_neighbours(graph, 0, {h->ref, h->ref_length}, {}); - float threshold= result.is_full() ? result.top()->distance : FLT_MAX; + h->ha_rnd_end(); + graph->file->ha_index_end(); + return 0; // TODO (error during store_link) + } + else + max_layer= graph->field[0]->val_int(); - // 6. add all its [yet unvisited] neighbours to the todo heap - keylen= graph->field[0]->get_key_image(key, ref_length, Field::itRAW); - if ((err= graph->file->ha_index_read_map(graph->record[0], key, 3, - HA_READ_KEY_EXACT))) - return HA_ERR_CRASHED; + FVector target; + h->position(table->record[0]); + // TODO (cvicentiu) Error checking. 
+ target.init(h->ref, h->ref_length, + reinterpret_cast(res->ptr()), + res->length() / sizeof(float)); - do { - if (!(str= graph->field[1]->val_str(&strbuf))) - return HA_ERR_CRASHED; - if (visited.find(str->ptr(), ref_length)) - continue; + std::uniform_real_distribution<> dis(0.0, 1.0); + double new_num= dis(gen); + double log= -std::log(new_num) * NORMALIZATION_FACTOR; + longlong new_node_layer= std::floor(log); - if ((err= table->file->ha_rnd_pos(table->record[0], (uchar*)str->ptr()))) - return HA_ERR_CRASHED; + List start_nodes; - float distance= dist->val_real(); - if (distance > threshold) - continue; + String ref_str, *ref_ptr; + ref_ptr= graph->field[1]->val_str(&ref_str); - cur= (Node*)table->in_use->alloc(element_size); - cur->distance= distance; - memcpy(cur->ref, str->ptr(), ref_length); - todo.push(cur); - visited.insert(cur); - } while (!graph->file->ha_index_next_same(graph->record[0], key, keylen)); - // 7. goto 3 + FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; + //FVector *start_node= start_node_ref.get_fvector_from_source(table, vec_field); + + // TODO(cvicentiu) error checking. Also make sure we use a random start node + // in last layer. + start_nodes.push_back(&start_node_ref); + // TODO start_nodes needs to have one element in it. 
+ for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) + { + List candidates; + search_layer(table, graph, vec_field, target, start_nodes, + table->in_use->variables.hnsw_ef_constructor, cur_layer, + &candidates); + start_nodes.empty(); + start_nodes.push_back(candidates.head()); + //candidates.delete_elements(); + //TODO(cvicentiu) memory leak } + for (longlong cur_layer= std::min(max_layer, new_node_layer); + cur_layer >= 0; cur_layer--) + { + List candidates; + List neighbours; + search_layer(table, graph, vec_field, target, start_nodes, + table->in_use->variables.hnsw_ef_constructor, + cur_layer, &candidates); + // release vectors + start_nodes.empty(); + + uint max_neighbours= (cur_layer == 0) ? + table->in_use->variables.hnsw_max_connection_per_layer * 2 + : table->in_use->variables.hnsw_max_connection_per_layer; + + select_neighbours(table, graph, vec_field, cur_layer, + target, candidates, + max_neighbours, &neighbours); + update_neighbours(table, graph, vec_field, cur_layer, max_neighbours, + target, neighbours); + start_nodes= candidates; + } + start_nodes.empty(); + + for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer; + cur_layer++) + { + write_neighbours(graph, cur_layer, target, {}); + } + + + h->ha_rnd_end(); + graph->file->ha_index_end(); + dbug_tmp_restore_column_map(&table->read_set, old_map); + + return err == HA_ERR_END_OF_FILE ? 0 : err; +} + + +int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) +{ + TABLE *graph= table->hlindex; + MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); + // TODO(cvicentiu) onlye one hlindex now. 
+ Field *vec_field= keyinfo->key_part->field; + Item_func_vec_distance *fun= (Item_func_vec_distance *)dist; + String buf, *res= fun->arguments()[1]->val_str(&buf); + handler *h= table->file; + + //TODO(scope_exit) + int err; + if ((err= h->ha_rnd_init(0))) + return err; + + if ((err= graph->file->ha_index_init(0, 1))) + return err; + + h->position(table->record[0]); + FVector target; + target.init(h->ref, + h->ref_length, + reinterpret_cast(res->ptr()), + res->length() / sizeof(float)); + + List candidates; + List start_nodes; + + longlong max_layer; + if ((err= graph->file->ha_index_last(graph->record[0]))) + { + if (err != HA_ERR_END_OF_FILE) + { + graph->file->ha_index_end(); + return err; + } + h->ha_rnd_end(); + graph->file->ha_index_end(); + return 0; // TODO (error during store_link) + } + else + max_layer= graph->field[0]->val_int(); + + String ref_str, *ref_ptr; + ref_ptr= graph->field[1]->val_str(&ref_str); + FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; + // TODO(cvicentiu) error checking. Also make sure we use a random start node + // in last layer. + start_nodes.push_back(&start_node_ref); + + ulonglong ef_search= MY_MAX( + table->in_use->variables.hnsw_ef_search, limit); + + for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) + { + search_layer(table, graph, vec_field, target, start_nodes, ef_search, + cur_layer, &candidates); + start_nodes.empty(); + //start_nodes.delete_elements(); + start_nodes.push_back(candidates.head()); + //candidates.delete_elements(); + candidates.empty(); + //TODO(cvicentiu) memleak. + } + + search_layer(table, graph, vec_field, target, start_nodes, + ef_search, 0, &candidates); + // 8. 
return results - Node **context= (Node**)table->in_use->alloc(sizeof(Node**)*result.elements()+1); + FVectorRef **context= (FVectorRef**)table->in_use->alloc( + sizeof(FVectorRef*) * (limit + 1)); graph->context= context; - Node **ptr= context+result.elements(); - *ptr= 0; - while (result.elements()) - *--ptr= result.pop(); + FVectorRef **ptr= context; + while (limit--) + *ptr++= candidates.pop(); + *ptr= nullptr; - return mhnsw_read_next(table); + err= mhnsw_read_next(table); + graph->file->ha_index_end(); + + // TODO release vectors after query + + dbug_tmp_restore_column_map(&table->read_set, old_map); + return err; } int mhnsw_read_next(TABLE *table) { - Node ***context= (Node***)&table->hlindex->context; - if (**context) - return table->file->ha_rnd_pos(table->record[0], (*(*context)++)->ref); + FVectorRef ***context= (FVectorRef ***)&table->hlindex->context; + FVectorRef *cur_vec= **context; + if (cur_vec) + { + int err= table->file->ha_rnd_pos(table->record[0], + (uchar *)(cur_vec)->get_ref()); + // release vectors + // delete cur_vec; + + (*context)++; + return err; + } return HA_ERR_END_OF_FILE; } diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index 74a0bd8ce42..141512244b7 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -15,10 +15,14 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */ +#include +#include "item.h" +#include "m_string.h" +#include "structs.h" #include "table.h" extern const LEX_CSTRING mhnsw_hlindex_table; int mhnsw_insert(TABLE *table, KEY *keyinfo); -int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit); +int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit); int mhnsw_read_next(TABLE *table);