mirror of
https://github.com/MariaDB/server.git
synced 2025-08-08 11:22:35 +03:00
Initial HNSW implementation
This commit includes the work done in collaboration with Hugo Wen from Amazon: MDEV-33408 Alter HNSW graph storage and fix memory leak This commit changes the way HNSW graph information is stored in the second table. Instead of storing connections as separate records, it now stores neighbors for each node, leading to significant performance improvements and storage savings. Comparing with the previous approach, the insert speed is 5 times faster, search speed improves by 23%, and storage usage is reduced by 73%, based on ann-benchmark tests with random-xs-20-euclidean and random-s-100-euclidean datasets. Additionally, in previous code, vector objects were not released after use, resulting in excessive memory consumption (over 20GB for building the index with 90,000 records), preventing tests with large datasets. Now ensure that vectors are released appropriately during the insert and search functions. Note there are still some vectors that need to be cleaned up after search query completion. Needs to be addressed in a future commit. All new code of the whole pull request, including one or several files that are either new files or modified ones, are contributed under the BSD-new license. I am contributing on behalf of my employer Amazon Web Services, Inc. As well as the commit: Introduce session variables to manage HNSW index parameters Three variables: hnsw_max_connection_per_layer hnsw_ef_constructor hnsw_ef_search ann-benchmark tool is also updated to support these variables in commit https://github.com/HugoWenTD/ann-benchmarks/commit/e09784e for branch https://github.com/HugoWenTD/ann-benchmarks/tree/mariadb-configurable All new code of the whole pull request, including one or several files that are either new files or modified ones, are contributed under the BSD-new license. I am contributing on behalf of my employer Amazon Web Services, Inc. Co-authored-by: Hugo Wen <wenhug@amazon.com>
This commit is contained in:
committed by
Sergei Golubchik
parent
26e5654301
commit
88839e71a3
@@ -412,6 +412,11 @@ The following specify which files/extra groups are read (specified before remain
|
|||||||
height-balanced, DOUBLE_PREC_HB - double precision
|
height-balanced, DOUBLE_PREC_HB - double precision
|
||||||
height-balanced, JSON_HB - height-balanced, stored as
|
height-balanced, JSON_HB - height-balanced, stored as
|
||||||
JSON
|
JSON
|
||||||
|
--hnsw-ef-constructor
|
||||||
|
hnsw_ef_constructor
|
||||||
|
--hnsw-ef-search hnsw_ef_search
|
||||||
|
--hnsw-max-connection-per-layer
|
||||||
|
hnsw_max_connection_per_layer
|
||||||
--host-cache-size=# How many host names should be cached to avoid resolving
|
--host-cache-size=# How many host names should be cached to avoid resolving
|
||||||
(Automatically configured unless set explicitly)
|
(Automatically configured unless set explicitly)
|
||||||
--idle-readonly-transaction-timeout=#
|
--idle-readonly-transaction-timeout=#
|
||||||
@@ -1732,6 +1737,9 @@ gtid-strict-mode FALSE
|
|||||||
help TRUE
|
help TRUE
|
||||||
histogram-size 254
|
histogram-size 254
|
||||||
histogram-type JSON_HB
|
histogram-type JSON_HB
|
||||||
|
hnsw-ef-constructor 10
|
||||||
|
hnsw-ef-search 10
|
||||||
|
hnsw-max-connection-per-layer 50
|
||||||
host-cache-size 279
|
host-cache-size 279
|
||||||
idle-readonly-transaction-timeout 0
|
idle-readonly-transaction-timeout 0
|
||||||
idle-transaction-timeout 0
|
idle-transaction-timeout 0
|
||||||
|
@@ -1432,6 +1432,36 @@ NUMERIC_BLOCK_SIZE NULL
|
|||||||
ENUM_VALUE_LIST SINGLE_PREC_HB,DOUBLE_PREC_HB,JSON_HB
|
ENUM_VALUE_LIST SINGLE_PREC_HB,DOUBLE_PREC_HB,JSON_HB
|
||||||
READ_ONLY NO
|
READ_ONLY NO
|
||||||
COMMAND_LINE_ARGUMENT REQUIRED
|
COMMAND_LINE_ARGUMENT REQUIRED
|
||||||
|
VARIABLE_NAME HNSW_EF_CONSTRUCTOR
|
||||||
|
VARIABLE_SCOPE SESSION
|
||||||
|
VARIABLE_TYPE INT UNSIGNED
|
||||||
|
VARIABLE_COMMENT hnsw_ef_constructor
|
||||||
|
NUMERIC_MIN_VALUE 0
|
||||||
|
NUMERIC_MAX_VALUE 4294967295
|
||||||
|
NUMERIC_BLOCK_SIZE 1
|
||||||
|
ENUM_VALUE_LIST NULL
|
||||||
|
READ_ONLY NO
|
||||||
|
COMMAND_LINE_ARGUMENT NONE
|
||||||
|
VARIABLE_NAME HNSW_EF_SEARCH
|
||||||
|
VARIABLE_SCOPE SESSION
|
||||||
|
VARIABLE_TYPE INT UNSIGNED
|
||||||
|
VARIABLE_COMMENT hnsw_ef_search
|
||||||
|
NUMERIC_MIN_VALUE 0
|
||||||
|
NUMERIC_MAX_VALUE 4294967295
|
||||||
|
NUMERIC_BLOCK_SIZE 1
|
||||||
|
ENUM_VALUE_LIST NULL
|
||||||
|
READ_ONLY NO
|
||||||
|
COMMAND_LINE_ARGUMENT NONE
|
||||||
|
VARIABLE_NAME HNSW_MAX_CONNECTION_PER_LAYER
|
||||||
|
VARIABLE_SCOPE SESSION
|
||||||
|
VARIABLE_TYPE INT UNSIGNED
|
||||||
|
VARIABLE_COMMENT hnsw_max_connection_per_layer
|
||||||
|
NUMERIC_MIN_VALUE 0
|
||||||
|
NUMERIC_MAX_VALUE 4294967295
|
||||||
|
NUMERIC_BLOCK_SIZE 1
|
||||||
|
ENUM_VALUE_LIST NULL
|
||||||
|
READ_ONLY NO
|
||||||
|
COMMAND_LINE_ARGUMENT NONE
|
||||||
VARIABLE_NAME HOSTNAME
|
VARIABLE_NAME HOSTNAME
|
||||||
VARIABLE_SCOPE GLOBAL
|
VARIABLE_SCOPE GLOBAL
|
||||||
VARIABLE_TYPE VARCHAR
|
VARIABLE_TYPE VARCHAR
|
||||||
|
@@ -6634,7 +6634,6 @@ public:
|
|||||||
#include "item_subselect.h"
|
#include "item_subselect.h"
|
||||||
#include "item_xmlfunc.h"
|
#include "item_xmlfunc.h"
|
||||||
#include "item_jsonfunc.h"
|
#include "item_jsonfunc.h"
|
||||||
#include "item_vectorfunc.h"
|
|
||||||
#include "item_create.h"
|
#include "item_create.h"
|
||||||
#include "item_vers.h"
|
#include "item_vers.h"
|
||||||
#endif
|
#endif
|
||||||
|
@@ -36,6 +36,7 @@
|
|||||||
#include "sp.h"
|
#include "sp.h"
|
||||||
#include "sql_time.h"
|
#include "sql_time.h"
|
||||||
#include "sql_type_geom.h"
|
#include "sql_type_geom.h"
|
||||||
|
#include "item_vectorfunc.h"
|
||||||
#include <mysql/plugin_function.h>
|
#include <mysql/plugin_function.h>
|
||||||
|
|
||||||
|
|
||||||
|
@@ -23,6 +23,7 @@
|
|||||||
|
|
||||||
#include <my_global.h>
|
#include <my_global.h>
|
||||||
#include "item.h"
|
#include "item.h"
|
||||||
|
#include "item_vectorfunc.h"
|
||||||
|
|
||||||
key_map Item_func_vec_distance::part_of_sortkey() const
|
key_map Item_func_vec_distance::part_of_sortkey() const
|
||||||
{
|
{
|
||||||
@@ -48,8 +49,18 @@ double Item_func_vec_distance::val_real()
|
|||||||
return 0;
|
return 0;
|
||||||
float *v1= (float*)r1->ptr();
|
float *v1= (float*)r1->ptr();
|
||||||
float *v2= (float*)r2->ptr();
|
float *v2= (float*)r2->ptr();
|
||||||
|
return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
|
||||||
|
{
|
||||||
|
float *p1= v1;
|
||||||
|
float *p2= v2;
|
||||||
double d= 0;
|
double d= 0;
|
||||||
for (uint i=0; i < r1->length() / sizeof(float); i++)
|
for (size_t i= 0; i < v_len; p1++, p2++, i++)
|
||||||
d+= (v1[i] - v2[i])*(v1[i] - v2[i]);
|
{
|
||||||
|
float dist= *p1 - *p2;
|
||||||
|
d+= dist * dist;
|
||||||
|
}
|
||||||
return sqrt(d);
|
return sqrt(d);
|
||||||
}
|
}
|
||||||
|
@@ -17,6 +17,8 @@
|
|||||||
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */
|
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */
|
||||||
|
|
||||||
/* This file defines all vector functions */
|
/* This file defines all vector functions */
|
||||||
|
#include <my_global.h>
|
||||||
|
#include "item.h"
|
||||||
#include "lex_string.h"
|
#include "lex_string.h"
|
||||||
#include "item_func.h"
|
#include "item_func.h"
|
||||||
|
|
||||||
@@ -34,6 +36,7 @@ class Item_func_vec_distance: public Item_real_func
|
|||||||
{
|
{
|
||||||
return check_argument_types_or_binary(NULL, 0, arg_count);
|
return check_argument_types_or_binary(NULL, 0, arg_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Item_func_vec_distance(THD *thd, Item *a, Item *b)
|
Item_func_vec_distance(THD *thd, Item *a, Item *b)
|
||||||
:Item_real_func(thd, a, b) {}
|
:Item_real_func(thd, a, b) {}
|
||||||
@@ -51,6 +54,9 @@ public:
|
|||||||
key_map part_of_sortkey() const override;
|
key_map part_of_sortkey() const override;
|
||||||
Item *do_get_copy(THD *thd) const override
|
Item *do_get_copy(THD *thd) const override
|
||||||
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
|
{ return get_item_copy<Item_func_vec_distance>(thd, this); }
|
||||||
|
virtual ~Item_func_vec_distance() {};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
double euclidean_vec_distance(float *v1, float *v2, size_t v_len);
|
||||||
#endif
|
#endif
|
||||||
|
@@ -9883,7 +9883,7 @@ int TABLE::hlindex_open(uint nr)
|
|||||||
mysql_mutex_unlock(&s->LOCK_share);
|
mysql_mutex_unlock(&s->LOCK_share);
|
||||||
TABLE *table= (TABLE*)alloc_root(&mem_root, sizeof(*table));
|
TABLE *table= (TABLE*)alloc_root(&mem_root, sizeof(*table));
|
||||||
if (!table ||
|
if (!table ||
|
||||||
open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, 0,
|
open_table_from_share(in_use, s->hlindex, &empty_clex_str, db_stat, EXTRA_RECORD,
|
||||||
in_use->open_options, table, 0))
|
in_use->open_options, table, 0))
|
||||||
return 1;
|
return 1;
|
||||||
hlindex= table;
|
hlindex= table;
|
||||||
@@ -9938,7 +9938,7 @@ int TABLE::hlindex_read_first(uint nr, Item *item, ulonglong limit)
|
|||||||
|
|
||||||
DBUG_ASSERT(hlindex->in_use == in_use);
|
DBUG_ASSERT(hlindex->in_use == in_use);
|
||||||
|
|
||||||
return mhnsw_read_first(this, item, limit);
|
return mhnsw_read_first(this, key_info + s->keys, item, limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
int TABLE::hlindex_read_next()
|
int TABLE::hlindex_read_next()
|
||||||
|
@@ -922,6 +922,11 @@ typedef struct system_variables
|
|||||||
my_bool binlog_alter_two_phase;
|
my_bool binlog_alter_two_phase;
|
||||||
|
|
||||||
Charset_collation_map_st character_set_collations;
|
Charset_collation_map_st character_set_collations;
|
||||||
|
|
||||||
|
/* Temporary for HNSW tests */
|
||||||
|
uint hnsw_max_connection_per_layer;
|
||||||
|
uint hnsw_ef_constructor;
|
||||||
|
uint hnsw_ef_search;
|
||||||
} SV;
|
} SV;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@@ -7447,3 +7447,23 @@ static Sys_var_ulonglong Sys_binlog_large_commit_threshold(
|
|||||||
// Allow a smaller minimum value for debug builds to help with testing
|
// Allow a smaller minimum value for debug builds to help with testing
|
||||||
VALID_RANGE(IF_DBUG(100, 10240) * 1024, ULLONG_MAX),
|
VALID_RANGE(IF_DBUG(100, 10240) * 1024, ULLONG_MAX),
|
||||||
DEFAULT(128 * 1024 * 1024), BLOCK_SIZE(1));
|
DEFAULT(128 * 1024 * 1024), BLOCK_SIZE(1));
|
||||||
|
|
||||||
|
/* Temporary for HNSW tests */
|
||||||
|
static Sys_var_uint Sys_hnsw_ef_search(
|
||||||
|
"hnsw_ef_search",
|
||||||
|
"hnsw_ef_search",
|
||||||
|
SESSION_VAR(hnsw_ef_search), CMD_LINE(NO_ARG),
|
||||||
|
VALID_RANGE(0, UINT_MAX), DEFAULT(10),
|
||||||
|
BLOCK_SIZE(1));
|
||||||
|
static Sys_var_uint Sys_hnsw_ef_constructor(
|
||||||
|
"hnsw_ef_constructor",
|
||||||
|
"hnsw_ef_constructor",
|
||||||
|
SESSION_VAR(hnsw_ef_constructor), CMD_LINE(NO_ARG),
|
||||||
|
VALID_RANGE(0, UINT_MAX), DEFAULT(10),
|
||||||
|
BLOCK_SIZE(1));
|
||||||
|
static Sys_var_uint Sys_hnsw_max_connection_per_layer(
|
||||||
|
"hnsw_max_connection_per_layer",
|
||||||
|
"hnsw_max_connection_per_layer",
|
||||||
|
SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(NO_ARG),
|
||||||
|
VALID_RANGE(0, UINT_MAX), DEFAULT(50),
|
||||||
|
BLOCK_SIZE(1));
|
||||||
|
@@ -14,196 +14,814 @@
|
|||||||
along with this program; if not, write to the Free Software
|
along with this program; if not, write to the Free Software
|
||||||
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
|
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
|
||||||
*/
|
*/
|
||||||
|
#include <random>
|
||||||
|
|
||||||
#include <my_global.h>
|
#include <my_global.h>
|
||||||
#include "vector_mhnsw.h"
|
#include "vector_mhnsw.h"
|
||||||
|
|
||||||
#include "field.h"
|
#include "field.h"
|
||||||
|
#include "hash.h"
|
||||||
#include "item.h"
|
#include "item.h"
|
||||||
|
#include "item_vectorfunc.h"
|
||||||
|
#include "key.h"
|
||||||
|
#include "my_base.h"
|
||||||
|
#include "mysql/psi/psi_base.h"
|
||||||
#include "sql_queue.h"
|
#include "sql_queue.h"
|
||||||
#include <scope.h>
|
|
||||||
|
#define HNSW_MAX_M 10000
|
||||||
|
|
||||||
const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
|
const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
|
||||||
CREATE TABLE i ( \
|
CREATE TABLE i ( \
|
||||||
|
layer int not null, \
|
||||||
src varbinary(255) not null, \
|
src varbinary(255) not null, \
|
||||||
dst varbinary(255) not null, \
|
neighbors varbinary(10000) not null, \
|
||||||
index (src)) \
|
index (layer, src)) \
|
||||||
")};
|
")};
|
||||||
|
|
||||||
static void store_ref(TABLE *t, handler *h, uint n)
|
|
||||||
|
class FVectorRef
|
||||||
{
|
{
|
||||||
t->hlindex->field[n]->store((char*)h->ref, h->ref_length, &my_charset_bin);
|
public:
|
||||||
|
// Shallow ref copy. Used for other ref lookups in HashSet
|
||||||
|
FVectorRef(uchar *ref, size_t ref_len): ref{ref}, ref_len{ref_len} {}
|
||||||
|
virtual ~FVectorRef() {}
|
||||||
|
|
||||||
|
static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
|
||||||
|
{
|
||||||
|
*key_len= elem->ref_len;
|
||||||
|
return elem->ref;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void free_vector(void *elem)
|
||||||
|
{
|
||||||
|
delete (FVectorRef *)elem;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_ref_len() const { return ref_len; }
|
||||||
|
const uchar* get_ref() const { return ref; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FVectorRef() = default;
|
||||||
|
uchar *ref;
|
||||||
|
size_t ref_len;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FVector;
|
||||||
|
|
||||||
|
Hash_set<FVectorRef> all_vector_set(
|
||||||
|
PSI_INSTRUMENT_MEM, &my_charset_bin,
|
||||||
|
1000, 0, 0,
|
||||||
|
(my_hash_get_key)FVectorRef::get_key,
|
||||||
|
NULL,
|
||||||
|
HASH_UNIQUE);
|
||||||
|
|
||||||
|
Hash_set<FVectorRef> all_vector_ref_set(
|
||||||
|
PSI_INSTRUMENT_MEM, &my_charset_bin,
|
||||||
|
1000, 0, 0,
|
||||||
|
(my_hash_get_key)FVectorRef::get_key,
|
||||||
|
NULL,
|
||||||
|
HASH_UNIQUE);
|
||||||
|
|
||||||
|
|
||||||
|
class FVector: public FVectorRef
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
float *vec;
|
||||||
|
size_t vec_len;
|
||||||
|
public:
|
||||||
|
FVector(): vec(nullptr), vec_len(0) {}
|
||||||
|
~FVector() { my_free(this->ref); }
|
||||||
|
|
||||||
|
bool init(const uchar *ref, size_t ref_len,
|
||||||
|
const float *vec, size_t vec_len)
|
||||||
|
{
|
||||||
|
this->ref= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED,
|
||||||
|
ref_len + vec_len * sizeof(float),
|
||||||
|
MYF(0));
|
||||||
|
if (!this->ref)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
this->vec= reinterpret_cast<float *>(this->ref + ref_len);
|
||||||
|
|
||||||
|
memcpy(this->ref, ref, ref_len);
|
||||||
|
memcpy(this->vec, vec, vec_len * sizeof(float));
|
||||||
|
|
||||||
|
this->ref_len= ref_len;
|
||||||
|
this->vec_len= vec_len;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_vec_len() const { return vec_len; }
|
||||||
|
const float* get_vec() const { return vec; }
|
||||||
|
|
||||||
|
double distance_to(const FVector &other) const
|
||||||
|
{
|
||||||
|
DBUG_ASSERT(other.vec_len == vec_len);
|
||||||
|
return euclidean_vec_distance(vec, other.vec, vec_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
|
||||||
|
{
|
||||||
|
FVectorRef tmp{(uchar*)ref, ref_len};
|
||||||
|
FVectorRef *v= all_vector_ref_set.find(&tmp);
|
||||||
|
if (v)
|
||||||
|
return v;
|
||||||
|
|
||||||
|
// TODO(cvicentiu) memory management.
|
||||||
|
uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0));
|
||||||
|
memcpy(buf, ref, ref_len);
|
||||||
|
v= new FVectorRef{buf, ref_len};
|
||||||
|
all_vector_ref_set.insert(v);
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FVector *get_fvector_from_source(TABLE *source,
|
||||||
|
Field *vect_field,
|
||||||
|
const FVectorRef &ref)
|
||||||
|
{
|
||||||
|
|
||||||
|
FVectorRef *v= all_vector_set.find(&ref);
|
||||||
|
if (v)
|
||||||
|
return (FVector *)v;
|
||||||
|
|
||||||
|
FVector *new_vector= new FVector;
|
||||||
|
if (!new_vector)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
source->file->ha_rnd_pos(source->record[0],
|
||||||
|
const_cast<uchar *>(ref.get_ref()));
|
||||||
|
|
||||||
|
String buf, *vec;
|
||||||
|
vec= vect_field->val_str(&buf);
|
||||||
|
|
||||||
|
// TODO(cvicentiu) error checking
|
||||||
|
new_vector->init(ref.get_ref(), ref.get_ref_len(),
|
||||||
|
reinterpret_cast<const float *>(vec->ptr()),
|
||||||
|
vec->length() / sizeof(float));
|
||||||
|
|
||||||
|
all_vector_set.insert(new_vector);
|
||||||
|
|
||||||
|
return new_vector;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b)
|
||||||
|
{
|
||||||
|
double a_dist= reference->distance_to(*a);
|
||||||
|
double b_dist= reference->distance_to(*b);
|
||||||
|
|
||||||
|
if (a_dist < b_dist)
|
||||||
|
return -1;
|
||||||
|
if (a_dist > b_dist)
|
||||||
|
return 1;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool KEEP_PRUNED_CONNECTIONS=true;
|
||||||
|
const bool EXTEND_CANDIDATES=true;
|
||||||
|
|
||||||
|
static bool get_neighbours(TABLE *graph,
|
||||||
|
size_t layer_number,
|
||||||
|
const FVectorRef &source_node,
|
||||||
|
List<FVectorRef> *neighbours);
|
||||||
|
|
||||||
|
static bool select_neighbours(TABLE *source, TABLE *graph,
|
||||||
|
Field *vect_field,
|
||||||
|
size_t layer_number,
|
||||||
|
const FVector &target,
|
||||||
|
const List<FVectorRef> &candidates,
|
||||||
|
size_t max_neighbour_connections,
|
||||||
|
List<FVectorRef> *neighbours)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
TODO: If the input neighbours list is already sorted in search_layer, then
|
||||||
|
no need to do additional queue build steps here.
|
||||||
|
*/
|
||||||
|
|
||||||
|
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin,
|
||||||
|
1000, 0, 0,
|
||||||
|
(my_hash_get_key)FVectorRef::get_key,
|
||||||
|
NULL,
|
||||||
|
HASH_UNIQUE);
|
||||||
|
|
||||||
|
Queue<FVector, const FVector> pq;
|
||||||
|
Queue<FVector, const FVector> pq_discard;
|
||||||
|
Queue<FVector, const FVector> best;
|
||||||
|
|
||||||
|
// TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size.
|
||||||
|
// This should not be fixed.
|
||||||
|
pq.init(10000, 0, cmp_vec, &target);
|
||||||
|
pq_discard.init(10000, 0, cmp_vec, &target);
|
||||||
|
best.init(max_neighbour_connections, true, cmp_vec, &target);
|
||||||
|
|
||||||
|
// TODO(cvicentiu) error checking.
|
||||||
|
for (const FVectorRef &candidate : candidates)
|
||||||
|
{
|
||||||
|
pq.push(FVector::get_fvector_from_source(source, vect_field, candidate));
|
||||||
|
visited.insert(&candidate);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (EXTEND_CANDIDATES)
|
||||||
|
{
|
||||||
|
for (const FVectorRef &candidate : candidates)
|
||||||
|
{
|
||||||
|
List<FVectorRef> candidate_neighbours;
|
||||||
|
get_neighbours(graph, layer_number, candidate, &candidate_neighbours);
|
||||||
|
for (const FVectorRef &extra_candidate : candidate_neighbours)
|
||||||
|
{
|
||||||
|
if (visited.find(&extra_candidate))
|
||||||
|
continue;
|
||||||
|
visited.insert(&extra_candidate);
|
||||||
|
pq.push(FVector::get_fvector_from_source(source,
|
||||||
|
vect_field,
|
||||||
|
extra_candidate));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DBUG_ASSERT(pq.elements());
|
||||||
|
best.push(pq.pop());
|
||||||
|
|
||||||
|
double best_top = best.top()->distance_to(target);
|
||||||
|
while (pq.elements() && best.elements() < max_neighbour_connections)
|
||||||
|
{
|
||||||
|
const FVector *vec= pq.pop();
|
||||||
|
double cur_dist = vec->distance_to(target);
|
||||||
|
// TODO(cvicentiu) best distance can be cached.
|
||||||
|
if (cur_dist < best_top) {
|
||||||
|
|
||||||
|
best.push(vec);
|
||||||
|
best_top = cur_dist;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
pq_discard.push(vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (KEEP_PRUNED_CONNECTIONS)
|
||||||
|
{
|
||||||
|
while (pq_discard.elements() &&
|
||||||
|
best.elements() < max_neighbour_connections)
|
||||||
|
{
|
||||||
|
best.push(pq_discard.pop());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DBUG_ASSERT(best.elements() <= max_neighbour_connections);
|
||||||
|
while (best.elements()) {
|
||||||
|
neighbours->push_front(best.pop());
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//static bool select_neighbours(TABLE *source, TABLE *graph,
|
||||||
|
// Field *vect_field,
|
||||||
|
// size_t layer_number,
|
||||||
|
// const FVector &target,
|
||||||
|
// const List<FVectorRef> &candidates,
|
||||||
|
// size_t max_neighbour_connections,
|
||||||
|
// List<FVectorRef> *neighbours)
|
||||||
|
//{
|
||||||
|
// /*
|
||||||
|
// TODO: If the input neighbours list is already sorted in search_layer, then
|
||||||
|
// no need to do additional queue build steps here.
|
||||||
|
// */
|
||||||
|
//
|
||||||
|
// Queue<FVector, const FVector> pq;
|
||||||
|
// pq.init(candidates.elements, 0, 0, cmp_vec, &target);
|
||||||
|
//
|
||||||
|
// // TODO(cvicentiu) error checking.
|
||||||
|
// for (const FVectorRef &candidate : candidates)
|
||||||
|
// pq.push(FVector::get_fvector_from_source(source, vect_field, candidate));
|
||||||
|
//
|
||||||
|
// for (size_t i = 0; i < max_neighbour_connections; i++)
|
||||||
|
// {
|
||||||
|
// if (!pq.elements())
|
||||||
|
// break;
|
||||||
|
// neighbours->push_back(pq.pop());
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// return false;
|
||||||
|
//}
|
||||||
|
|
||||||
|
|
||||||
|
static void dbug_print_vec_ref(const char *prefix,
|
||||||
|
uint layer,
|
||||||
|
const FVectorRef &ref)
|
||||||
|
{
|
||||||
|
#ifndef DBUG_OFF
|
||||||
|
// TODO(cvicentiu) disable this in release build.
|
||||||
|
char *ref_str= (char *)alloca(ref.get_ref_len() * 2 + 1);
|
||||||
|
DBUG_ASSERT(ref_str);
|
||||||
|
char *ptr= ref_str;
|
||||||
|
for (size_t i = 0; i < ref.get_ref_len(); ptr += 2, i++)
|
||||||
|
{
|
||||||
|
snprintf(ptr, 3, "%02x", ref.get_ref()[i]);
|
||||||
|
}
|
||||||
|
DBUG_PRINT("VECTOR", ("%s %u %s", prefix, layer, ref_str));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dbug_print_vec_neigh(uint layer,
|
||||||
|
const List<FVectorRef> &neighbors)
|
||||||
|
{
|
||||||
|
#ifndef DBUG_OFF
|
||||||
|
DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements));
|
||||||
|
for (const FVectorRef& ref : neighbors)
|
||||||
|
{
|
||||||
|
dbug_print_vec_ref("NEIGH: ", layer, ref);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool write_neighbours(TABLE *graph,
|
||||||
|
size_t layer_number,
|
||||||
|
const FVectorRef &source_node,
|
||||||
|
const List<FVectorRef> &new_neighbours)
|
||||||
|
{
|
||||||
|
DBUG_ASSERT(new_neighbours.elements <= HNSW_MAX_M);
|
||||||
|
|
||||||
|
|
||||||
|
size_t total_size= sizeof(uint16_t) +
|
||||||
|
new_neighbours.elements * source_node.get_ref_len();
|
||||||
|
|
||||||
|
// Allocate memory for the struct and the flexible array member
|
||||||
|
char *neighbor_array_bytes= static_cast<char *>(alloca(total_size));
|
||||||
|
|
||||||
|
DBUG_ASSERT(new_neighbours.elements <= INT16_MAX);
|
||||||
|
*(uint16_t *) neighbor_array_bytes= new_neighbours.elements;
|
||||||
|
char *pos= neighbor_array_bytes + sizeof(uint16_t);
|
||||||
|
for (const auto &node: new_neighbours)
|
||||||
|
{
|
||||||
|
DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len());
|
||||||
|
memcpy(pos, node.get_ref(), node.get_ref_len());
|
||||||
|
pos+= node.get_ref_len();
|
||||||
|
}
|
||||||
|
|
||||||
|
graph->field[0]->store(layer_number);
|
||||||
|
graph->field[1]->store_binary(
|
||||||
|
reinterpret_cast<const char *>(source_node.get_ref()),
|
||||||
|
source_node.get_ref_len());
|
||||||
|
graph->field[2]->set_null();
|
||||||
|
|
||||||
|
|
||||||
|
uchar *key= (uchar*)alloca(graph->key_info->key_length);
|
||||||
|
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
|
||||||
|
|
||||||
|
int err= graph->file->ha_index_read_map(graph->record[1], key,
|
||||||
|
HA_WHOLE_KEY,
|
||||||
|
HA_READ_KEY_EXACT);
|
||||||
|
|
||||||
|
// no record
|
||||||
|
if (err == HA_ERR_KEY_NOT_FOUND)
|
||||||
|
{
|
||||||
|
dbug_print_vec_ref("INSERT ", layer_number, source_node);
|
||||||
|
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
|
||||||
|
graph->file->ha_write_row(graph->record[0]);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
dbug_print_vec_ref("UPDATE ", layer_number, source_node);
|
||||||
|
dbug_print_vec_neigh(layer_number, new_neighbours);
|
||||||
|
|
||||||
|
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
|
||||||
|
graph->file->ha_update_row(graph->record[1], graph->record[0]);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static bool get_neighbours(TABLE *graph,
|
||||||
|
size_t layer_number,
|
||||||
|
const FVectorRef &source_node,
|
||||||
|
List<FVectorRef> *neighbours)
|
||||||
|
{
|
||||||
|
// TODO(cvicentiu) This allocation need not happen in this function.
|
||||||
|
uchar *key= (uchar*)alloca(graph->key_info->key_length);
|
||||||
|
|
||||||
|
graph->field[0]->store(layer_number);
|
||||||
|
graph->field[1]->store_binary(
|
||||||
|
reinterpret_cast<const char *>(source_node.get_ref()),
|
||||||
|
source_node.get_ref_len());
|
||||||
|
graph->field[2]->set_null();
|
||||||
|
key_copy(key, graph->record[0],
|
||||||
|
graph->key_info, graph->key_info->key_length);
|
||||||
|
if ((graph->file->ha_index_read_map(graph->record[0], key,
|
||||||
|
HA_WHOLE_KEY,
|
||||||
|
HA_READ_KEY_EXACT)))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
//TODO This does two memcpys, one should use str's buffer.
|
||||||
|
String strbuf;
|
||||||
|
String *str= graph->field[2]->val_str(&strbuf);
|
||||||
|
|
||||||
|
// All ref should have same length
|
||||||
|
uint ref_length= source_node.get_ref_len();
|
||||||
|
|
||||||
|
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr());
|
||||||
|
uint16_t number_of_neighbours=
|
||||||
|
*reinterpret_cast<const uint16_t*>(neigh_arr_bytes);
|
||||||
|
if (number_of_neighbours != (str->length() - sizeof(uint16_t)) / ref_length)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
neighbours number does not match the data length,
|
||||||
|
should not happen, possible corrupted HNSW index
|
||||||
|
*/
|
||||||
|
DBUG_ASSERT(0); // TODO(cvicentiu) remove this after testing.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uchar *pos = neigh_arr_bytes + sizeof(uint16_t);
|
||||||
|
for (uint16_t i= 0; i < number_of_neighbours; i++)
|
||||||
|
{
|
||||||
|
neighbours->push_back(FVector::get_fvector_ref(pos, ref_length));
|
||||||
|
pos+= ref_length;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static bool update_second_degree_neighbors(TABLE *source,
|
||||||
|
Field *vec_field,
|
||||||
|
TABLE *graph,
|
||||||
|
size_t layer_number,
|
||||||
|
uint max_neighbours,
|
||||||
|
const FVectorRef &source_node,
|
||||||
|
const List<FVectorRef> &neighbours)
|
||||||
|
{
|
||||||
|
//dbug_print_vec_ref("Updating second degree neighbours", layer_number, source_node);
|
||||||
|
//dbug_print_vec_neigh(layer_number, neighbours);
|
||||||
|
for (const FVectorRef &neigh: neighbours)
|
||||||
|
{
|
||||||
|
List<FVectorRef> new_neighbours;
|
||||||
|
get_neighbours(graph, layer_number, neigh, &new_neighbours);
|
||||||
|
new_neighbours.push_back(&source_node);
|
||||||
|
write_neighbours(graph, layer_number, neigh, new_neighbours);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const FVectorRef &neigh: neighbours)
|
||||||
|
{
|
||||||
|
List<FVectorRef> new_neighbours;
|
||||||
|
get_neighbours(graph, layer_number, neigh, &new_neighbours);
|
||||||
|
// TODO(cvicentiu) get_fvector_from_source results must not need to be freed.
|
||||||
|
FVector *neigh_vec = FVector::get_fvector_from_source(source, vec_field, neigh);
|
||||||
|
|
||||||
|
if (new_neighbours.elements > max_neighbours)
|
||||||
|
{
|
||||||
|
// shrink the neighbours
|
||||||
|
List<FVectorRef> selected;
|
||||||
|
select_neighbours(source, graph, vec_field, layer_number,
|
||||||
|
*neigh_vec, new_neighbours,
|
||||||
|
max_neighbours, &selected);
|
||||||
|
write_neighbours(graph, layer_number, neigh, selected);
|
||||||
|
}
|
||||||
|
|
||||||
|
// release memory
|
||||||
|
new_neighbours.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static bool update_neighbours(TABLE *source,
|
||||||
|
TABLE *graph,
|
||||||
|
Field *vec_field,
|
||||||
|
size_t layer_number,
|
||||||
|
uint max_neighbours,
|
||||||
|
const FVectorRef &source_node,
|
||||||
|
const List<FVectorRef> &neighbours)
|
||||||
|
{
|
||||||
|
// 1. update node's neighbours
|
||||||
|
write_neighbours(graph, layer_number, source_node, neighbours);
|
||||||
|
// 2. update node's neighbours' neighbours (shrink before update)
|
||||||
|
update_second_degree_neighbors(source, vec_field, graph, layer_number,
|
||||||
|
max_neighbours, source_node, neighbours);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static bool search_layer(TABLE *source,
|
||||||
|
TABLE *graph,
|
||||||
|
Field *vec_field,
|
||||||
|
const FVector &target,
|
||||||
|
const List<FVectorRef> &start_nodes,
|
||||||
|
uint max_candidates_return,
|
||||||
|
size_t layer,
|
||||||
|
List<FVectorRef> *result)
|
||||||
|
{
|
||||||
|
DBUG_ASSERT(start_nodes.elements > 0);
|
||||||
|
// Result list must be empty, otherwise there's a risk of memory leak
|
||||||
|
DBUG_ASSERT(result->elements == 0);
|
||||||
|
|
||||||
|
Queue<FVector, const FVector> candidates;
|
||||||
|
Queue<FVector, const FVector> best;
|
||||||
|
//TODO(cvicentiu) Fix this hash method.
|
||||||
|
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin,
|
||||||
|
1000, 0, 0,
|
||||||
|
(my_hash_get_key)FVectorRef::get_key,
|
||||||
|
NULL,
|
||||||
|
HASH_UNIQUE);
|
||||||
|
|
||||||
|
candidates.init(10000, false, cmp_vec, &target);
|
||||||
|
best.init(max_candidates_return, true, cmp_vec, &target);
|
||||||
|
|
||||||
|
for (const FVectorRef &node : start_nodes)
|
||||||
|
{
|
||||||
|
FVector *v= FVector::get_fvector_from_source(source, vec_field, node);
|
||||||
|
candidates.push(v);
|
||||||
|
if (best.elements() < max_candidates_return)
|
||||||
|
best.push(v);
|
||||||
|
else if (target.distance_to(*v) > target.distance_to(*best.top())) {
|
||||||
|
best.replace_top(v);
|
||||||
|
}
|
||||||
|
visited.insert(v);
|
||||||
|
dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
double furthest_best = target.distance_to(*best.top());
|
||||||
|
while (candidates.elements())
|
||||||
|
{
|
||||||
|
const FVector &cur_vec= *candidates.pop();
|
||||||
|
double cur_distance = target.distance_to(cur_vec);
|
||||||
|
if (cur_distance > furthest_best && best.elements() == max_candidates_return)
|
||||||
|
{
|
||||||
|
break; // All possible candidates are worse than what we have.
|
||||||
|
// Can't get better.
|
||||||
|
}
|
||||||
|
|
||||||
|
List<FVectorRef> neighbours;
|
||||||
|
get_neighbours(graph, layer, cur_vec, &neighbours);
|
||||||
|
|
||||||
|
for (const FVectorRef &neigh: neighbours)
|
||||||
|
{
|
||||||
|
if (visited.find(&neigh))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
FVector *clone = FVector::get_fvector_from_source(source, vec_field, neigh);
|
||||||
|
// TODO(cvicentiu) mem ownershipw...
|
||||||
|
visited.insert(clone);
|
||||||
|
if (best.elements() < max_candidates_return)
|
||||||
|
{
|
||||||
|
candidates.push(clone);
|
||||||
|
best.push(clone);
|
||||||
|
furthest_best = target.distance_to(*best.top());
|
||||||
|
}
|
||||||
|
else if (target.distance_to(*clone) < furthest_best)
|
||||||
|
{
|
||||||
|
best.replace_top(clone);
|
||||||
|
candidates.push(clone);
|
||||||
|
furthest_best = target.distance_to(*best.top());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
neighbours.empty();
|
||||||
|
}
|
||||||
|
DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));
|
||||||
|
|
||||||
|
while (best.elements())
|
||||||
|
{
|
||||||
|
// TODO(cvicentiu) FVector memory leak.
|
||||||
|
// TODO(cvicentiu) this is n*log(n), we need a queue iterator.
|
||||||
|
result->push_front(best.pop());
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::mt19937 gen(42); // TODO(cvicentiu) seeded with 42 for now, this should
|
||||||
|
// use a rnd service
|
||||||
|
|
||||||
int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
int mhnsw_insert(TABLE *table, KEY *keyinfo)
|
||||||
{
|
{
|
||||||
TABLE *graph= table->hlindex;
|
TABLE *graph= table->hlindex;
|
||||||
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
|
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
|
||||||
Field *field= keyinfo->key_part->field;
|
Field *vec_field= keyinfo->key_part->field;
|
||||||
String buf, *res= field->val_str(&buf);
|
String buf, *res= vec_field->val_str(&buf);
|
||||||
handler *h= table->file;
|
handler *h= table->file;
|
||||||
int err= 0;
|
int err= 0;
|
||||||
dbug_tmp_restore_column_map(&table->read_set, old_map);
|
|
||||||
|
|
||||||
/* metadata are checked on open */
|
/* metadata are checked on open */
|
||||||
DBUG_ASSERT(graph);
|
DBUG_ASSERT(graph);
|
||||||
DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
|
DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
|
||||||
DBUG_ASSERT(keyinfo->usable_key_parts == 1);
|
DBUG_ASSERT(keyinfo->usable_key_parts == 1);
|
||||||
DBUG_ASSERT(field->binary());
|
DBUG_ASSERT(vec_field->binary());
|
||||||
DBUG_ASSERT(field->cmp_type() == STRING_RESULT);
|
DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
|
||||||
DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
|
DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
|
||||||
DBUG_ASSERT(h->ref_length <= graph->field[0]->field_length);
|
|
||||||
DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length);
|
DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length);
|
||||||
|
DBUG_ASSERT(h->ref_length <= graph->field[2]->field_length);
|
||||||
|
|
||||||
if (res->length() == 0 || res->length() % 4)
|
if (res->length() == 0 || res->length() % 4)
|
||||||
return 1;
|
return 1;
|
||||||
|
|
||||||
// let's do every node to every node
|
const double NORMALIZATION_FACTOR = 1 / std::log(1.0 *
|
||||||
h->position(table->record[0]);
|
table->in_use->variables.hnsw_max_connection_per_layer);
|
||||||
graph->field[0]->store(1);
|
|
||||||
store_ref(table, h, 0);
|
|
||||||
|
|
||||||
if (h->lookup_handler->ha_rnd_init(1))
|
if ((err= h->ha_rnd_init(1)))
|
||||||
return 1;
|
return err;
|
||||||
while (! ((err= h->lookup_handler->ha_rnd_next(h->lookup_buffer))))
|
|
||||||
{
|
|
||||||
h->lookup_handler->position(h->lookup_buffer);
|
|
||||||
if (graph->field[0]->cmp(h->lookup_handler->ref) == 0)
|
|
||||||
continue;
|
|
||||||
store_ref(table, h->lookup_handler, 1);
|
|
||||||
if ((err= graph->file->ha_write_row(graph->record[0])))
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
h->lookup_handler->ha_rnd_end();
|
|
||||||
|
|
||||||
return err == HA_ERR_END_OF_FILE ? 0 : err;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Node
|
|
||||||
{
|
|
||||||
float distance;
|
|
||||||
uchar ref[1000];
|
|
||||||
};
|
|
||||||
|
|
||||||
static int cmp_float(void *, const Node *a, const Node *b)
|
|
||||||
{
|
|
||||||
return a->distance < b->distance ? -1 : a->distance == b->distance ? 0 : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit)
|
|
||||||
{
|
|
||||||
TABLE *graph= table->hlindex;
|
|
||||||
Queue<Node> todo, result;
|
|
||||||
Node *cur;
|
|
||||||
String *str, strbuf;
|
|
||||||
const size_t ref_length= table->file->ref_length;
|
|
||||||
const size_t element_size= ref_length + sizeof(float);
|
|
||||||
uchar *key= (uchar*)alloca(ref_length + 32);
|
|
||||||
Hash_set<Node> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, limit,
|
|
||||||
sizeof(float), ref_length, 0, 0, HASH_UNIQUE);
|
|
||||||
uint keylen;
|
|
||||||
int err= 0;
|
|
||||||
|
|
||||||
DBUG_ASSERT(graph);
|
|
||||||
|
|
||||||
if (todo.init(1000, 0, cmp_float)) // XXX + autoextent
|
|
||||||
return HA_ERR_OUT_OF_MEM;
|
|
||||||
|
|
||||||
if (result.init(limit, 1, cmp_float))
|
|
||||||
return HA_ERR_OUT_OF_MEM;
|
|
||||||
|
|
||||||
if ((err= graph->file->ha_index_init(0, 1)))
|
if ((err= graph->file->ha_index_init(0, 1)))
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
|
longlong max_layer;
|
||||||
|
|
||||||
// 1. read a start row
|
|
||||||
if ((err= graph->file->ha_index_last(graph->record[0])))
|
if ((err= graph->file->ha_index_last(graph->record[0])))
|
||||||
return err;
|
|
||||||
|
|
||||||
if (!(str= graph->field[0]->val_str(&strbuf)))
|
|
||||||
return HA_ERR_CRASHED;
|
|
||||||
|
|
||||||
DBUG_ASSERT(str->length() == ref_length);
|
|
||||||
|
|
||||||
cur= (Node*)table->in_use->alloc(element_size);
|
|
||||||
memcpy(cur->ref, str->ptr(), ref_length);
|
|
||||||
|
|
||||||
if ((err= table->file->ha_rnd_init(0)))
|
|
||||||
return err;
|
|
||||||
|
|
||||||
if ((err= table->file->ha_rnd_pos(table->record[0], cur->ref)))
|
|
||||||
return HA_ERR_CRASHED;
|
|
||||||
|
|
||||||
// 2. add it to the todo
|
|
||||||
cur->distance= dist->val_real();
|
|
||||||
if (dist->is_null())
|
|
||||||
return HA_ERR_END_OF_FILE;
|
|
||||||
todo.push(cur);
|
|
||||||
visited.insert(cur);
|
|
||||||
|
|
||||||
while (todo.elements())
|
|
||||||
{
|
{
|
||||||
// 3. pick the top node from the todo
|
if (err != HA_ERR_END_OF_FILE)
|
||||||
cur= todo.pop();
|
|
||||||
|
|
||||||
// 4. add it to the result
|
|
||||||
if (result.is_full())
|
|
||||||
{
|
{
|
||||||
// 5. if not added, greedy search done
|
graph->file->ha_index_end();
|
||||||
if (cur->distance > result.top()->distance)
|
return err;
|
||||||
break;
|
}
|
||||||
result.replace_top(cur);
|
// First insert!
|
||||||
|
h->position(table->record[0]);
|
||||||
|
write_neighbours(graph, 0, {h->ref, h->ref_length}, {});
|
||||||
|
|
||||||
|
h->ha_rnd_end();
|
||||||
|
graph->file->ha_index_end();
|
||||||
|
return 0; // TODO (error during store_link)
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
result.push(cur);
|
max_layer= graph->field[0]->val_int();
|
||||||
|
|
||||||
float threshold= result.is_full() ? result.top()->distance : FLT_MAX;
|
FVector target;
|
||||||
|
h->position(table->record[0]);
|
||||||
|
// TODO (cvicentiu) Error checking.
|
||||||
|
target.init(h->ref, h->ref_length,
|
||||||
|
reinterpret_cast<const float *>(res->ptr()),
|
||||||
|
res->length() / sizeof(float));
|
||||||
|
|
||||||
// 6. add all its [yet unvisited] neighbours to the todo heap
|
|
||||||
keylen= graph->field[0]->get_key_image(key, ref_length, Field::itRAW);
|
|
||||||
if ((err= graph->file->ha_index_read_map(graph->record[0], key, 3,
|
|
||||||
HA_READ_KEY_EXACT)))
|
|
||||||
return HA_ERR_CRASHED;
|
|
||||||
|
|
||||||
do {
|
std::uniform_real_distribution<> dis(0.0, 1.0);
|
||||||
if (!(str= graph->field[1]->val_str(&strbuf)))
|
double new_num= dis(gen);
|
||||||
return HA_ERR_CRASHED;
|
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
|
||||||
|
longlong new_node_layer= std::floor(log);
|
||||||
|
|
||||||
if (visited.find(str->ptr(), ref_length))
|
List<FVectorRef> start_nodes;
|
||||||
continue;
|
|
||||||
|
|
||||||
if ((err= table->file->ha_rnd_pos(table->record[0], (uchar*)str->ptr())))
|
String ref_str, *ref_ptr;
|
||||||
return HA_ERR_CRASHED;
|
ref_ptr= graph->field[1]->val_str(&ref_str);
|
||||||
|
|
||||||
float distance= dist->val_real();
|
FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()};
|
||||||
if (distance > threshold)
|
//FVector *start_node= start_node_ref.get_fvector_from_source(table, vec_field);
|
||||||
continue;
|
|
||||||
|
|
||||||
cur= (Node*)table->in_use->alloc(element_size);
|
// TODO(cvicentiu) error checking. Also make sure we use a random start node
|
||||||
cur->distance= distance;
|
// in last layer.
|
||||||
memcpy(cur->ref, str->ptr(), ref_length);
|
start_nodes.push_back(&start_node_ref);
|
||||||
todo.push(cur);
|
// TODO start_nodes needs to have one element in it.
|
||||||
visited.insert(cur);
|
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
|
||||||
} while (!graph->file->ha_index_next_same(graph->record[0], key, keylen));
|
{
|
||||||
// 7. goto 3
|
List<FVectorRef> candidates;
|
||||||
|
search_layer(table, graph, vec_field, target, start_nodes,
|
||||||
|
table->in_use->variables.hnsw_ef_constructor, cur_layer,
|
||||||
|
&candidates);
|
||||||
|
start_nodes.empty();
|
||||||
|
start_nodes.push_back(candidates.head());
|
||||||
|
//candidates.delete_elements();
|
||||||
|
//TODO(cvicentiu) memory leak
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (longlong cur_layer= std::min(max_layer, new_node_layer);
|
||||||
|
cur_layer >= 0; cur_layer--)
|
||||||
|
{
|
||||||
|
List<FVectorRef> candidates;
|
||||||
|
List<FVectorRef> neighbours;
|
||||||
|
search_layer(table, graph, vec_field, target, start_nodes,
|
||||||
|
table->in_use->variables.hnsw_ef_constructor,
|
||||||
|
cur_layer, &candidates);
|
||||||
|
// release vectors
|
||||||
|
start_nodes.empty();
|
||||||
|
|
||||||
|
uint max_neighbours= (cur_layer == 0) ?
|
||||||
|
table->in_use->variables.hnsw_max_connection_per_layer * 2
|
||||||
|
: table->in_use->variables.hnsw_max_connection_per_layer;
|
||||||
|
|
||||||
|
select_neighbours(table, graph, vec_field, cur_layer,
|
||||||
|
target, candidates,
|
||||||
|
max_neighbours, &neighbours);
|
||||||
|
update_neighbours(table, graph, vec_field, cur_layer, max_neighbours,
|
||||||
|
target, neighbours);
|
||||||
|
start_nodes= candidates;
|
||||||
|
}
|
||||||
|
start_nodes.empty();
|
||||||
|
|
||||||
|
for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer;
|
||||||
|
cur_layer++)
|
||||||
|
{
|
||||||
|
write_neighbours(graph, cur_layer, target, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
h->ha_rnd_end();
|
||||||
|
graph->file->ha_index_end();
|
||||||
|
dbug_tmp_restore_column_map(&table->read_set, old_map);
|
||||||
|
|
||||||
|
return err == HA_ERR_END_OF_FILE ? 0 : err;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
|
||||||
|
{
|
||||||
|
TABLE *graph= table->hlindex;
|
||||||
|
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
|
||||||
|
// TODO(cvicentiu) onlye one hlindex now.
|
||||||
|
Field *vec_field= keyinfo->key_part->field;
|
||||||
|
Item_func_vec_distance *fun= (Item_func_vec_distance *)dist;
|
||||||
|
String buf, *res= fun->arguments()[1]->val_str(&buf);
|
||||||
|
handler *h= table->file;
|
||||||
|
|
||||||
|
//TODO(scope_exit)
|
||||||
|
int err;
|
||||||
|
if ((err= h->ha_rnd_init(0)))
|
||||||
|
return err;
|
||||||
|
|
||||||
|
if ((err= graph->file->ha_index_init(0, 1)))
|
||||||
|
return err;
|
||||||
|
|
||||||
|
h->position(table->record[0]);
|
||||||
|
FVector target;
|
||||||
|
target.init(h->ref,
|
||||||
|
h->ref_length,
|
||||||
|
reinterpret_cast<const float *>(res->ptr()),
|
||||||
|
res->length() / sizeof(float));
|
||||||
|
|
||||||
|
List<FVectorRef> candidates;
|
||||||
|
List<FVectorRef> start_nodes;
|
||||||
|
|
||||||
|
longlong max_layer;
|
||||||
|
if ((err= graph->file->ha_index_last(graph->record[0])))
|
||||||
|
{
|
||||||
|
if (err != HA_ERR_END_OF_FILE)
|
||||||
|
{
|
||||||
|
graph->file->ha_index_end();
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
h->ha_rnd_end();
|
||||||
|
graph->file->ha_index_end();
|
||||||
|
return 0; // TODO (error during store_link)
|
||||||
|
}
|
||||||
|
else
|
||||||
|
max_layer= graph->field[0]->val_int();
|
||||||
|
|
||||||
|
String ref_str, *ref_ptr;
|
||||||
|
ref_ptr= graph->field[1]->val_str(&ref_str);
|
||||||
|
FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()};
|
||||||
|
// TODO(cvicentiu) error checking. Also make sure we use a random start node
|
||||||
|
// in last layer.
|
||||||
|
start_nodes.push_back(&start_node_ref);
|
||||||
|
|
||||||
|
ulonglong ef_search= MY_MAX(
|
||||||
|
table->in_use->variables.hnsw_ef_search, limit);
|
||||||
|
|
||||||
|
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
|
||||||
|
{
|
||||||
|
search_layer(table, graph, vec_field, target, start_nodes, ef_search,
|
||||||
|
cur_layer, &candidates);
|
||||||
|
start_nodes.empty();
|
||||||
|
//start_nodes.delete_elements();
|
||||||
|
start_nodes.push_back(candidates.head());
|
||||||
|
//candidates.delete_elements();
|
||||||
|
candidates.empty();
|
||||||
|
//TODO(cvicentiu) memleak.
|
||||||
|
}
|
||||||
|
|
||||||
|
search_layer(table, graph, vec_field, target, start_nodes,
|
||||||
|
ef_search, 0, &candidates);
|
||||||
|
|
||||||
// 8. return results
|
// 8. return results
|
||||||
Node **context= (Node**)table->in_use->alloc(sizeof(Node**)*result.elements()+1);
|
FVectorRef **context= (FVectorRef**)table->in_use->alloc(
|
||||||
|
sizeof(FVectorRef*) * (limit + 1));
|
||||||
graph->context= context;
|
graph->context= context;
|
||||||
|
|
||||||
Node **ptr= context+result.elements();
|
FVectorRef **ptr= context;
|
||||||
*ptr= 0;
|
while (limit--)
|
||||||
while (result.elements())
|
*ptr++= candidates.pop();
|
||||||
*--ptr= result.pop();
|
*ptr= nullptr;
|
||||||
|
|
||||||
return mhnsw_read_next(table);
|
err= mhnsw_read_next(table);
|
||||||
|
graph->file->ha_index_end();
|
||||||
|
|
||||||
|
// TODO release vectors after query
|
||||||
|
|
||||||
|
dbug_tmp_restore_column_map(&table->read_set, old_map);
|
||||||
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
int mhnsw_read_next(TABLE *table)
|
int mhnsw_read_next(TABLE *table)
|
||||||
{
|
{
|
||||||
Node ***context= (Node***)&table->hlindex->context;
|
FVectorRef ***context= (FVectorRef ***)&table->hlindex->context;
|
||||||
if (**context)
|
FVectorRef *cur_vec= **context;
|
||||||
return table->file->ha_rnd_pos(table->record[0], (*(*context)++)->ref);
|
if (cur_vec)
|
||||||
|
{
|
||||||
|
int err= table->file->ha_rnd_pos(table->record[0],
|
||||||
|
(uchar *)(cur_vec)->get_ref());
|
||||||
|
// release vectors
|
||||||
|
// delete cur_vec;
|
||||||
|
|
||||||
|
(*context)++;
|
||||||
|
return err;
|
||||||
|
}
|
||||||
return HA_ERR_END_OF_FILE;
|
return HA_ERR_END_OF_FILE;
|
||||||
}
|
}
|
||||||
|
@@ -15,10 +15,14 @@
|
|||||||
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
|
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <my_global.h>
|
||||||
|
#include "item.h"
|
||||||
|
#include "m_string.h"
|
||||||
|
#include "structs.h"
|
||||||
#include "table.h"
|
#include "table.h"
|
||||||
|
|
||||||
extern const LEX_CSTRING mhnsw_hlindex_table;
|
extern const LEX_CSTRING mhnsw_hlindex_table;
|
||||||
|
|
||||||
int mhnsw_insert(TABLE *table, KEY *keyinfo);
|
int mhnsw_insert(TABLE *table, KEY *keyinfo);
|
||||||
int mhnsw_read_first(TABLE *table, Item *dist, ulonglong limit);
|
int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit);
|
||||||
int mhnsw_read_next(TABLE *table);
|
int mhnsw_read_next(TABLE *table);
|
||||||
|
Reference in New Issue
Block a user