1
0
mirror of https://github.com/redis/node-redis.git synced 2025-08-07 13:22:56 +03:00

Support Vector Similarity (#1785)

* ft.alter

* support paramas

* remove only and skip

* merge

* fix imports

* add Vector field

* update version

* push attributes

* typo

* test

* version check

* remove .only

* remove unued import

* add support for DIALECT

* clean code

Co-authored-by: Avital-Fine <avital.fine@redis.com>
Co-authored-by: leibale <leibale1998@gmail.com>
This commit is contained in:
Avital Fine
2022-03-31 13:13:06 +02:00
committed by GitHub
parent 33a3f3f6c6
commit 4683e969b8
16 changed files with 631 additions and 304 deletions

View File

@@ -28,7 +28,7 @@ import * as SUGLEN from './SUGLEN';
import * as SYNDUMP from './SYNDUMP';
import * as SYNUPDATE from './SYNUPDATE';
import * as TAGVALS from './TAGVALS';
import { RedisCommandArguments } from '@node-redis/client/dist/lib/commands';
import { RedisCommandArgument, RedisCommandArguments } from '@node-redis/client/dist/lib/commands';
import { pushOptionalVerdictArgument, pushVerdictArgument } from '@node-redis/client/dist/lib/commands/generic-transformers';
import { SearchOptions } from './SEARCH';
@@ -172,16 +172,29 @@ export enum SchemaFieldTypes {
TEXT = 'TEXT',
NUMERIC = 'NUMERIC',
GEO = 'GEO',
TAG = 'TAG'
TAG = 'TAG',
VECTOR = 'VECTOR'
}
type CreateSchemaField<T extends SchemaFieldTypes, E = Record<keyof any, any>> = T | ({
type CreateSchemaField<
T extends SchemaFieldTypes,
E = Record<keyof any, any>
> = T | ({
type: T;
AS?: string;
SORTABLE?: true | 'UNF';
NOINDEX?: true;
} & E);
type CreateSchemaCommonField<
T extends SchemaFieldTypes,
E = Record<string, never>
> = CreateSchemaField<
T,
({
SORTABLE?: true | 'UNF';
NOINDEX?: true;
} & E)
>;
export enum SchemaTextFieldPhonetics {
DM_EN = 'dm:en',
DM_FR = 'dm:fr',
@@ -189,27 +202,55 @@ export enum SchemaTextFieldPhonetics {
DM_ES = 'dm:es'
}
type CreateSchemaTextField = CreateSchemaField<SchemaFieldTypes.TEXT, {
type CreateSchemaTextField = CreateSchemaCommonField<SchemaFieldTypes.TEXT, {
NOSTEM?: true;
WEIGHT?: number;
PHONETIC?: SchemaTextFieldPhonetics;
}>;
type CreateSchemaNumericField = CreateSchemaField<SchemaFieldTypes.NUMERIC>;
type CreateSchemaNumericField = CreateSchemaCommonField<SchemaFieldTypes.NUMERIC>;
type CreateSchemaGeoField = CreateSchemaField<SchemaFieldTypes.GEO>;
type CreateSchemaGeoField = CreateSchemaCommonField<SchemaFieldTypes.GEO>;
type CreateSchemaTagField = CreateSchemaField<SchemaFieldTypes.TAG, {
type CreateSchemaTagField = CreateSchemaCommonField<SchemaFieldTypes.TAG, {
SEPARATOR?: string;
CASESENSITIVE?: true;
}>;
export enum VectorAlgorithms {
FLAT = 'FLAT',
HNSW = 'HNSW'
}
type CreateSchemaVectorField<
T extends VectorAlgorithms,
A extends Record<string, unknown>
> = CreateSchemaField<SchemaFieldTypes.VECTOR, {
ALGORITHM: T;
TYPE: string;
DIM: number;
DISTANCE_METRIC: 'L2' | 'IP' | 'COSINE';
INITIAL_CAP?: number;
} & A>;
type CreateSchemaFlatVectorField = CreateSchemaVectorField<VectorAlgorithms.FLAT, {
BLOCK_SIZE?: number;
}>;
type CreateSchemaHNSWVectorField = CreateSchemaVectorField<VectorAlgorithms.HNSW, {
M?: number;
EF_CONSTRUCTION?: number;
EF_RUNTIME?: number;
}>;
export interface RediSearchSchema {
[field: string]:
CreateSchemaTextField |
CreateSchemaNumericField |
CreateSchemaGeoField |
CreateSchemaTagField;
CreateSchemaTagField |
CreateSchemaFlatVectorField |
CreateSchemaHNSWVectorField;
}
export function pushSchema(args: RedisCommandArguments, schema: RediSearchSchema) {
@@ -257,6 +298,47 @@ export function pushSchema(args: RedisCommandArguments, schema: RediSearchSchema
}
break;
case SchemaFieldTypes.VECTOR:
args.push(fieldOptions.ALGORITHM);
pushArgumentsWithLength(args, () => {
args.push(
'TYPE', fieldOptions.TYPE,
'DIM', fieldOptions.DIM.toString(),
'DISTANCE_METRIC', fieldOptions.DISTANCE_METRIC
);
if (fieldOptions.INITIAL_CAP) {
args.push('INITIAL_CAP', fieldOptions.INITIAL_CAP.toString());
}
switch (fieldOptions.ALGORITHM) {
case VectorAlgorithms.FLAT:
if (fieldOptions.BLOCK_SIZE) {
args.push('BLOCK_SIZE', fieldOptions.BLOCK_SIZE.toString());
}
break;
case VectorAlgorithms.HNSW:
if (fieldOptions.M) {
args.push('M', fieldOptions.M.toString());
}
if (fieldOptions.EF_CONSTRUCTION) {
args.push('EF_CONSTRUCTION', fieldOptions.EF_CONSTRUCTION.toString());
}
if (fieldOptions.EF_RUNTIME) {
args.push('EF_RUNTIME', fieldOptions.EF_RUNTIME.toString());
}
break;
}
});
continue; // vector fields do not contain SORTABLE and NOINDEX options
}
if (fieldOptions.SORTABLE) {
@@ -273,11 +355,27 @@ export function pushSchema(args: RedisCommandArguments, schema: RediSearchSchema
}
}
export type Params = Record<string, RedisCommandArgument | number>;
export function pushParamsArgs(
args: RedisCommandArguments,
params?: Params
): RedisCommandArguments {
if (params) {
const enrties = Object.entries(params);
args.push('PARAMS', (enrties.length * 2).toString());
for (const [key, value] of enrties) {
args.push(key, value.toString());
}
}
return args;
}
export function pushSearchOptions(
args: RedisCommandArguments,
options?: SearchOptions
): RedisCommandArguments {
if (options?.VERBATIM) {
args.push('VERBATIM');
}
@@ -381,6 +479,16 @@ export function pushSearchOptions(
);
}
if (options?.PARAMS) {
pushParamsArgs(args, options.PARAMS);
}
if (options?.DIALECT) {
args.push('DIALECT', options.DIALECT.toString());
}
console.log('!@#', args);
return args;
}