diff --git a/packages/client/lib/client/commands-queue.ts b/packages/client/lib/client/commands-queue.ts index ae67ca28cd..9b7f737113 100644 --- a/packages/client/lib/client/commands-queue.ts +++ b/packages/client/lib/client/commands-queue.ts @@ -338,6 +338,10 @@ export default class RedisCommandsQueue { return this.#addPubSubCommand(command); } + removeAllPubSubListeners() { + return this.#pubSub.removeAllListeners(); + } + resubscribe(chainId?: symbol) { const commands = this.#pubSub.resubscribe(); if (!commands.length) return; diff --git a/packages/client/lib/client/index.ts b/packages/client/lib/client/index.ts index 919baab0df..ea2102c37f 100644 --- a/packages/client/lib/client/index.ts +++ b/packages/client/lib/client/index.ts @@ -765,7 +765,7 @@ export default class RedisClient< } }); } - + if (this.#clientSideCache) { commands.push({cmd: this.#clientSideCache.trackingOn()}); } @@ -773,9 +773,9 @@ export default class RedisClient< if (this.#options?.emitInvalidate) { commands.push({cmd: ['CLIENT', 'TRACKING', 'ON']}); } - + const maintenanceHandshakeCmd = await EnterpriseMaintenanceManager.getHandshakeCommand(this.#options); - + if(maintenanceHandshakeCmd) { commands.push(maintenanceHandshakeCmd); }; @@ -818,6 +818,11 @@ export default class RedisClient< chainId = Symbol('Socket Initiator'); const resubscribePromise = this.#queue.resubscribe(chainId); + resubscribePromise?.catch(error => { + if (error.message && error.message.startsWith('MOVED')) { + this.emit('__MOVED', this._self.#queue.removeAllPubSubListeners()); + } + }); if (resubscribePromise) { promises.push(resubscribePromise); } diff --git a/packages/client/lib/client/pub-sub.ts b/packages/client/lib/client/pub-sub.ts index 1387aea841..1895f96a88 100644 --- a/packages/client/lib/client/pub-sub.ts +++ b/packages/client/lib/client/pub-sub.ts @@ -323,27 +323,52 @@ export class PubSub { } resubscribe() { - const commands = []; + const commands: PubSubCommand[] = []; for (const [type, listeners] of Object.entries(this.listeners)) { if (!listeners.size) continue; this.#isActive = true; - this.#subscribing++; - const callback = () => this.#subscribing--; - commands.push({ - args: [ - COMMANDS[type as PubSubType].subscribe, - ...listeners.keys() - ], - channelsCounter: listeners.size, - resolve: callback, - reject: callback - } satisfies PubSubCommand); + + if(type === PUBSUB_TYPE.SHARDED) { + this.#shardedResubscribe(commands, listeners); + } else { + this.#normalResubscribe(commands, type, listeners); + } } return commands; } + #normalResubscribe(commands: PubSubCommand[], type: string, listeners: PubSubTypeListeners) { + this.#subscribing++; + const callback = () => this.#subscribing--; + commands.push({ + args: [ + COMMANDS[type as PubSubType].subscribe, + ...listeners.keys() + ], + channelsCounter: listeners.size, + resolve: callback, + reject: callback + }); + } + + #shardedResubscribe(commands: PubSubCommand[], listeners: PubSubTypeListeners) { + const callback = () => this.#subscribing--; + for(const channel of listeners.keys()) { + this.#subscribing++; + commands.push({ + args: [ + COMMANDS[PUBSUB_TYPE.SHARDED].subscribe, + channel + ], + channelsCounter: 1, + resolve: callback, + reject: callback + }) + } + } + handleMessageReply(reply: Array): boolean { if (COMMANDS[PUBSUB_TYPE.CHANNELS].message.equals(reply[0])) { this.#emitPubSubMessage( @@ -379,6 +404,22 @@ export class PubSub { return listeners; } + removeAllListeners() { + const result = { + [PUBSUB_TYPE.CHANNELS]: this.listeners[PUBSUB_TYPE.CHANNELS], + [PUBSUB_TYPE.PATTERNS]: this.listeners[PUBSUB_TYPE.PATTERNS], + [PUBSUB_TYPE.SHARDED]: this.listeners[PUBSUB_TYPE.SHARDED] + } + + this.#updateIsActive(); + + this.listeners[PUBSUB_TYPE.CHANNELS] = new Map(); + this.listeners[PUBSUB_TYPE.PATTERNS] = new Map(); + this.listeners[PUBSUB_TYPE.SHARDED] = new Map(); + + return result; + } + #emitPubSubMessage( type: PubSubType, message: Buffer, diff --git a/packages/client/lib/cluster/cluster-slots.ts b/packages/client/lib/cluster/cluster-slots.ts index 737413677e..ae81495843 100644 --- a/packages/client/lib/cluster/cluster-slots.ts +++ b/packages/client/lib/cluster/cluster-slots.ts @@ -2,7 +2,7 @@ import { RedisClusterClientOptions, RedisClusterOptions } from '.'; import { RootNodesUnavailableError } from '../errors'; import RedisClient, { RedisClientOptions, RedisClientType } from '../client'; import { EventEmitter } from 'node:stream'; -import { ChannelListeners, PUBSUB_TYPE, PubSubTypeListeners } from '../client/pub-sub'; +import { ChannelListeners, PUBSUB_TYPE, PubSubListeners, PubSubTypeListeners } from '../client/pub-sub'; import { RedisArgument, RedisFunctions, RedisModules, RedisScripts, RespVersions, TypeMapping } from '../RESP/types'; import calculateSlot from 'cluster-key-slot'; import { RedisSocketOptions } from '../client/socket'; @@ -185,6 +185,7 @@ export default class RedisClusterSlots< async #discover(rootNode: RedisClusterClientOptions) { this.clientSideCache?.clear(); this.clientSideCache?.disable(); + try { const addressesInUse = new Set(), promises: Array> = [], @@ -224,6 +225,7 @@ export default class RedisClusterSlots< } } + //Keep only the nodes that are still in use for (const [address, node] of this.nodeByAddress.entries()) { if (addressesInUse.has(address)) continue; @@ -337,23 +339,29 @@ export default class RedisClusterSlots< const socket = this.#getNodeAddress(node.address) ?? { host: node.host, port: node.port, }; - const client = Object.freeze({ + const clientInfo = Object.freeze({ host: socket.host, port: socket.port, }); const emit = this.#emit; - return this.#clientFactory( + const client = this.#clientFactory( this.#clientOptionsDefaults({ clientSideCache: this.clientSideCache, RESP: this.#options.RESP, socket, readonly, })) - .on('error', error => emit('node-error', error, client)) - .on('reconnecting', () => emit('node-reconnecting', client)) - .once('ready', () => emit('node-ready', client)) - .once('connect', () => emit('node-connect', client)) - .once('end', () => emit('node-disconnect', client)); + .on('error', error => emit('node-error', error, clientInfo)) + .on('reconnecting', () => emit('node-reconnecting', clientInfo)) + .once('ready', () => emit('node-ready', clientInfo)) + .once('connect', () => emit('node-connect', clientInfo)) + .once('end', () => emit('node-disconnect', clientInfo)) + .on('__MOVED', async (allPubSubListeners: PubSubListeners) => { + await this.rediscover(client); + this.#emit('__resubscribeAllPubSubListeners', allPubSubListeners); + }); + + return client; } #createNodeClient(node: ShardNode, readonly?: boolean) { @@ -374,7 +382,9 @@ export default class RedisClusterSlots< async rediscover(startWith: RedisClientType): Promise { this.#runningRediscoverPromise ??= this.#rediscover(startWith) - .finally(() => this.#runningRediscoverPromise = undefined); + .finally(() => { + this.#runningRediscoverPromise = undefined + }); return this.#runningRediscoverPromise; } diff --git a/packages/client/lib/cluster/index.ts b/packages/client/lib/cluster/index.ts index 16454e66fb..238f3a5919 100644 --- a/packages/client/lib/cluster/index.ts +++ b/packages/client/lib/cluster/index.ts @@ -6,7 +6,7 @@ import { EventEmitter } from 'node:events'; import { attachConfig, functionArgumentsPrefix, getTransformReply, scriptArgumentsPrefix } from '../commander'; import RedisClusterSlots, { NodeAddressMap, ShardNode } from './cluster-slots'; import RedisClusterMultiCommand, { RedisClusterMultiCommandType } from './multi-command'; -import { PubSubListener } from '../client/pub-sub'; +import { PubSubListener, PubSubListeners } from '../client/pub-sub'; import { ErrorReply } from '../errors'; import { RedisTcpSocketOptions } from '../client/socket'; import { ClientSideCacheConfig, PooledClientSideCacheProvider } from '../client/cache'; @@ -310,6 +310,7 @@ export default class RedisCluster< this._options = options; this._slots = new RedisClusterSlots(options, this.emit.bind(this)); + this.on('__resubscribeAllPubSubListeners', this.resubscribeAllPubSubListeners.bind(this)); if (options?.commandOptions) { this._commandOptions = options.commandOptions; @@ -584,6 +585,33 @@ export default class RedisCluster< ); } + resubscribeAllPubSubListeners(allListeners: PubSubListeners) { + for(const [channel, listeners] of allListeners.CHANNELS) { + listeners.buffers.forEach(bufListener => { + this.subscribe(channel, bufListener, true); + }); + listeners.strings.forEach(strListener => { + this.subscribe(channel, strListener); + }); + }; + for (const [channel, listeners] of allListeners.PATTERNS) { + listeners.buffers.forEach(bufListener => { + this.pSubscribe(channel, bufListener, true); + }); + listeners.strings.forEach(strListener => { + this.pSubscribe(channel, strListener); + }); + }; + for (const [channel, listeners] of allListeners.SHARDED) { + listeners.buffers.forEach(bufListener => { + this.sSubscribe(channel, bufListener, true); + }); + listeners.strings.forEach(strListener => { + this.sSubscribe(channel, strListener); + }); + }; + } + sUnsubscribe = this.SUNSUBSCRIBE; /** diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts new file mode 100644 index 0000000000..46ef252da8 --- /dev/null +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts @@ -0,0 +1,362 @@ +import type { Cluster, TestConfig } from "./utils/test.util"; +import { createClusterTestClient, getConfig } from "./utils/test.util"; +import { FaultInjectorClient } from "../fault-injector-client"; +import { TestCommandRunner } from "./utils/command-runner"; +import { CHANNELS, CHANNELS_BY_SLOT } from "./utils/test.util"; +import { MessageTracker } from "./utils/message-tracker"; +import assert from "node:assert"; +import { setTimeout } from "node:timers/promises"; + +describe("Sharded Pub/Sub E2E", () => { + let faultInjectorClient: FaultInjectorClient; + let config: TestConfig; + + before(() => { + config = getConfig(); + + faultInjectorClient = new FaultInjectorClient(config.faultInjectorUrl); + }); + + describe("Single Subscriber", () => { + let subscriber: Cluster; + let publisher: Cluster; + let messageTracker: MessageTracker; + + beforeEach(async () => { + messageTracker = new MessageTracker(CHANNELS); + subscriber = createClusterTestClient(config.clientConfig, {}); + publisher = createClusterTestClient(config.clientConfig, {}); + await Promise.all([subscriber.connect(), publisher.connect()]); + }); + + afterEach(async () => { + await Promise.all([subscriber.quit(), publisher.quit()]); + }); + + it("should receive messages published to multiple channels", async () => { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_msg, channel) => + messageTracker.incrementReceived(channel), + ); + } + const { controller, result } = + TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker, + ); + // Wait for 10 seconds, while publishing messages + await setTimeout(10_000); + controller.abort(); + await result; + + for (const channel of CHANNELS) { + assert.strictEqual( + messageTracker.getChannelStats(channel)?.received, + messageTracker.getChannelStats(channel)?.sent, + ); + } + }); + + it("should resume publishing and receiving after failover", async () => { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_msg, channel) => { + messageTracker.incrementReceived(channel); + }); + } + + // Trigger failover twice + for (let i = 0; i < 2; i++) { + // Start publishing messages + const { controller: publishAbort, result: publishResult } = + TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker, + ); + + // Trigger failover during publishing + const { action_id: failoverActionId } = + await faultInjectorClient.triggerAction({ + type: "failover", + parameters: { + bdb_id: config.clientConfig.bdbId.toString(), + cluster_index: 0, + }, + }); + + // Wait for failover to complete + await faultInjectorClient.waitForAction(failoverActionId); + + publishAbort.abort(); + await publishResult; + + for (const channel of CHANNELS) { + const sent = messageTracker.getChannelStats(channel)!.sent; + const received = messageTracker.getChannelStats(channel)!.received; + + assert.ok( + received <= sent, + `Channel ${channel}: received (${received}) should be <= sent (${sent})`, + ); + } + + // Wait for 2 seconds before resuming publishing + await setTimeout(2_000); + messageTracker.reset(); + + const { + controller: afterFailoverController, + result: afterFailoverResult, + } = TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker, + ); + + await setTimeout(10_000); + afterFailoverController.abort(); + await afterFailoverResult; + + for (const channel of CHANNELS) { + const sent = messageTracker.getChannelStats(channel)!.sent; + const received = messageTracker.getChannelStats(channel)!.received; + assert.ok(sent > 0, `Channel ${channel} should have sent messages`); + assert.ok( + received > 0, + `Channel ${channel} should have received messages`, + ); + assert.strictEqual( + messageTracker.getChannelStats(channel)!.received, + messageTracker.getChannelStats(channel)!.sent, + `Channel ${channel} received (${received}) should equal sent (${sent}) once resumed after failover`, + ); + } + } + }); + + it("should NOT receive messages after sunsubscribe", async () => { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_msg, channel) => messageTracker.incrementReceived(channel)); + } + + const { controller, result } = + TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker, + ); + + // Wait for 5 seconds, while publishing messages + await setTimeout(5_000); + controller.abort(); + await result; + + for (const channel of CHANNELS) { + assert.strictEqual( + messageTracker.getChannelStats(channel)?.received, + messageTracker.getChannelStats(channel)?.sent, + ); + } + + // Reset message tracker + messageTracker.reset(); + + const unsubscribeChannels = [ + CHANNELS_BY_SLOT["1000"], + CHANNELS_BY_SLOT["8000"], + CHANNELS_BY_SLOT["16000"], + ]; + + for (const channel of unsubscribeChannels) { + await subscriber.sUnsubscribe(channel); + } + + const { + controller: afterUnsubscribeController, + result: afterUnsubscribeResult, + } = TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker, + ); + + // Wait for 5 seconds, while publishing messages + await setTimeout(5_000); + afterUnsubscribeController.abort(); + await afterUnsubscribeResult; + + for (const channel of unsubscribeChannels) { + assert.strictEqual( + messageTracker.getChannelStats(channel)?.received, + 0, + `Channel ${channel} should not have received messages after unsubscribe`, + ); + } + + // All other channels should have received messages + const stillSubscribedChannels = CHANNELS.filter( + (channel) => !unsubscribeChannels.includes(channel as any), + ); + + for (const channel of stillSubscribedChannels) { + assert.ok( + messageTracker.getChannelStats(channel)!.received > 0, + `Channel ${channel} should have received messages`, + ); + } + }); + }); + + describe("Multiple Subscribers", () => { + let subscriber1: Cluster; + let subscriber2: Cluster; + + let publisher: Cluster; + + let messageTracker1: MessageTracker; + let messageTracker2: MessageTracker; + + beforeEach(async () => { + messageTracker1 = new MessageTracker(CHANNELS); + messageTracker2 = new MessageTracker(CHANNELS); + subscriber1 = createClusterTestClient(config.clientConfig); + subscriber2 = createClusterTestClient(config.clientConfig); + publisher = createClusterTestClient(config.clientConfig); + await Promise.all([ + subscriber1.connect(), + subscriber2.connect(), + publisher.connect(), + ]); + }); + + afterEach(async () => { + await Promise.all([ + subscriber1.quit(), + subscriber2.quit(), + publisher.quit(), + ]); + }); + + it("should receive messages published to multiple channels", async () => { + for (const channel of CHANNELS) { + await subscriber1.sSubscribe(channel, (_msg, channel) => { messageTracker1.incrementReceived(channel); }); + await subscriber2.sSubscribe(channel, (_msg, channel) => { messageTracker2.incrementReceived(channel); }); + } + + const { controller, result } = + TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker1, // Use messageTracker1 for all publishing + ); + + // Wait for 10 seconds, while publishing messages + await setTimeout(10_000); + controller.abort(); + await result; + + for (const channel of CHANNELS) { + assert.strictEqual( + messageTracker1.getChannelStats(channel)?.received, + messageTracker1.getChannelStats(channel)?.sent, + ); + assert.strictEqual( + messageTracker2.getChannelStats(channel)?.received, + messageTracker1.getChannelStats(channel)?.sent, + ); + } + }); + + it("should resume publishing and receiving after failover", async () => { + for (const channel of CHANNELS) { + await subscriber1.sSubscribe(channel, (_msg, channel) => { messageTracker1.incrementReceived(channel); }); + await subscriber2.sSubscribe(channel, (_msg, channel) => { messageTracker2.incrementReceived(channel); }); + } + + // Start publishing messages + const { controller: publishAbort, result: publishResult } = + TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker1, // Use messageTracker1 for all publishing + ); + + // Trigger failover during publishing + const { action_id: failoverActionId } = + await faultInjectorClient.triggerAction({ + type: "failover", + parameters: { + bdb_id: config.clientConfig.bdbId.toString(), + cluster_index: 0, + }, + }); + + // Wait for failover to complete + await faultInjectorClient.waitForAction(failoverActionId); + + publishAbort.abort(); + await publishResult; + + for (const channel of CHANNELS) { + const sent = messageTracker1.getChannelStats(channel)!.sent; + const received1 = messageTracker1.getChannelStats(channel)!.received; + + const received2 = messageTracker2.getChannelStats(channel)!.received; + + assert.ok( + received1 <= sent, + `Channel ${channel}: received (${received1}) should be <= sent (${sent})`, + ); + assert.ok( + received2 <= sent, + `Channel ${channel}: received2 (${received2}) should be <= sent (${sent})`, + ); + } + + // Wait for 2 seconds before resuming publishing + await setTimeout(2_000); + + messageTracker1.reset(); + messageTracker2.reset(); + + const { + controller: afterFailoverController, + result: afterFailoverResult, + } = TestCommandRunner.publishMessagesUntilAbortSignal( + publisher, + CHANNELS, + messageTracker1, + ); + + await setTimeout(10_000); + afterFailoverController.abort(); + await afterFailoverResult; + + for (const channel of CHANNELS) { + const sent = messageTracker1.getChannelStats(channel)!.sent; + const received1 = messageTracker1.getChannelStats(channel)!.received; + const received2 = messageTracker2.getChannelStats(channel)!.received; + assert.ok(sent > 0, `Channel ${channel} should have sent messages`); + assert.ok( + received1 > 0, + `Channel ${channel} should have received messages by subscriber 1`, + ); + assert.ok( + received2 > 0, + `Channel ${channel} should have received messages by subscriber 2`, + ); + assert.strictEqual( + received1, + sent, + `Channel ${channel} received (${received1}) should equal sent (${sent}) once resumed after failover by subscriber 1`, + ); + assert.strictEqual( + received2, + sent, + `Channel ${channel} received (${received2}) should equal sent (${sent}) once resumed after failover by subscriber 2`, + ); + } + }); + }); +}); diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts new file mode 100644 index 0000000000..7b1a217bbf --- /dev/null +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts @@ -0,0 +1,90 @@ +import type { MessageTracker } from "./message-tracker"; +import { Cluster } from "./test.util"; +import { setTimeout } from "timers/promises"; + +/** + * Options for the `publishMessagesUntilAbortSignal` method + */ +interface PublishMessagesUntilAbortSignalOptions { + /** + * Number of messages to publish in each batch + */ + batchSize: number; + /** + * Timeout between batches in milliseconds + */ + timeoutMs: number; + /** + * Function that generates the message content to be published + */ + createMessage: () => string; +} + +/** + * Utility class for running test commands until a stop signal is received + */ +export class TestCommandRunner { + private static readonly defaultPublishOptions: PublishMessagesUntilAbortSignalOptions = + { + batchSize: 10, + timeoutMs: 10, + createMessage: () => Date.now().toString(), + }; + + /** + * Continuously publishes messages to the given Redis channels until aborted. + * + * @param {Redis|Cluster} client - Redis client or cluster instance used to publish messages. + * @param {string[]} channels - List of channel names to publish messages to. + * @param {MessageTracker} messageTracker - Tracks sent and failed message counts per channel. + * @param {Partial} [options] - Optional overrides for batch size, timeout, and message factory. + * @param {AbortController} [externalAbortController] - Optional external abort controller to control publishing lifecycle. + * @returns {{ controller: AbortController, result: Promise }} + * An object containing the abort controller and a promise that resolves when publishing stops. + */ + static publishMessagesUntilAbortSignal( + client: Cluster, + channels: string[], + messageTracker: MessageTracker, + options?: Partial, + externalAbortController?: AbortController, + ) { + const publishOptions = { + ...TestCommandRunner.defaultPublishOptions, + ...options, + }; + + const abortController = externalAbortController ?? new AbortController(); + + const result = async () => { + while (!abortController.signal.aborted) { + const batchPromises: Promise[] = []; + + for (let i = 0; i < publishOptions.batchSize; i++) { + for (const channel of channels) { + const message = publishOptions.createMessage(); + + const publishPromise = client + .sPublish(channel, message) + .then(() => { + messageTracker.incrementSent(channel); + }) + .catch(() => { + messageTracker.incrementFailed(channel); + }); + + batchPromises.push(publishPromise); + } + } + + await Promise.all(batchPromises); + await setTimeout(publishOptions.timeoutMs); + } + }; + + return { + controller: abortController, + result: result(), + }; + } +} diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/message-tracker.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/message-tracker.ts new file mode 100644 index 0000000000..6393356c8c --- /dev/null +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/message-tracker.ts @@ -0,0 +1,52 @@ +export interface MessageStats { + sent: number; + received: number; + failed: number; +} + +export class MessageTracker { + private stats: Record = {}; + + constructor(channels: string[]) { + this.initializeChannels(channels); + } + + private initializeChannels(channels: string[]): void { + this.stats = channels.reduce((acc, channel) => { + acc[channel] = { sent: 0, received: 0, failed: 0 }; + return acc; + }, {} as Record); + } + + reset(): void { + Object.keys(this.stats).forEach((channel) => { + this.stats[channel] = { sent: 0, received: 0, failed: 0 }; + }); + } + + incrementSent(channel: string): void { + if (this.stats[channel]) { + this.stats[channel].sent++; + } + } + + incrementReceived(channel: string): void { + if (this.stats[channel]) { + this.stats[channel].received++; + } + } + + incrementFailed(channel: string): void { + if (this.stats[channel]) { + this.stats[channel].failed++; + } + } + + getChannelStats(channel: string): MessageStats | undefined { + return this.stats[channel]; + } + + getAllStats(): Record { + return this.stats; + } +} diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/test.util.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/test.util.ts new file mode 100644 index 0000000000..9ef683e5e5 --- /dev/null +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/test.util.ts @@ -0,0 +1,211 @@ +import { readFileSync } from "fs"; +import RedisCluster, { + RedisClusterOptions, +} from "../../../../cluster"; + +interface DatabaseEndpoint { + addr: string[]; + addr_type: string; + dns_name: string; + oss_cluster_api_preferred_endpoint_type: string; + oss_cluster_api_preferred_ip_type: string; + port: number; + proxy_policy: string; + uid: string; +} + +interface DatabaseConfig { + bdb_id: number; + username: string; + password: string; + tls: boolean; + raw_endpoints: DatabaseEndpoint[]; + endpoints: string[]; +} + +type DatabasesConfig = Record; + +interface EnvConfig { + redisEndpointsConfigPath: string; + faultInjectorUrl: string; +} + +export interface RedisConnectionConfig { + host: string; + port: number; + username: string; + password: string; + tls: boolean; + bdbId: number; +} + +export interface TestConfig { + clientConfig: RedisConnectionConfig; + faultInjectorUrl: string; +} + +/** + * Reads environment variables required for the test scenario + * @returns Environment configuration object + * @throws Error if required environment variables are not set + */ +const getEnvConfig = (): EnvConfig => { + if (!process.env["REDIS_ENDPOINTS_CONFIG_PATH"]) { + throw new Error( + "REDIS_ENDPOINTS_CONFIG_PATH environment variable must be set", + ); + } + + if (!process.env["RE_FAULT_INJECTOR_URL"]) { + throw new Error("RE_FAULT_INJECTOR_URL environment variable must be set"); + } + + return { + redisEndpointsConfigPath: process.env["REDIS_ENDPOINTS_CONFIG_PATH"], + faultInjectorUrl: process.env["RE_FAULT_INJECTOR_URL"], + }; +}; + +/** + * Reads database configuration from a file + * @param filePath - The path to the database configuration file + * @returns Parsed database configuration object + * @throws Error if file doesn't exist or JSON is invalid + */ +const getDatabaseConfigFromEnv = (filePath: string): DatabasesConfig => { + try { + const fileContent = readFileSync(filePath, "utf8"); + return JSON.parse(fileContent) as DatabasesConfig; + } catch (_error) { + throw new Error(`Failed to read or parse database config from ${filePath}`); + } +}; + +/** + * Gets Redis connection parameters for a specific database + * @param databasesConfig - The parsed database configuration object + * @param databaseName - Optional name of the database to retrieve (defaults to the first one) + * @returns Redis connection configuration with host, port, username, password, and tls + * @throws Error if the specified database is not found in the configuration + */ +const getDatabaseConfig = ( + databasesConfig: DatabasesConfig, + databaseName?: string, +): RedisConnectionConfig => { + const dbConfig = databaseName + ? databasesConfig[databaseName] + : Object.values(databasesConfig)[0]; + + if (!dbConfig) { + throw new Error( + `Database ${databaseName || ""} not found in configuration`, + ); + } + + const endpoint = dbConfig.raw_endpoints[0]; // Use the first endpoint + + if (!endpoint) { + throw new Error(`No endpoints found for database ${databaseName}`); + } + + return { + host: endpoint.dns_name, + port: endpoint.port, + username: dbConfig.username, + password: dbConfig.password, + tls: dbConfig.tls, + bdbId: dbConfig.bdb_id, + }; +}; + +/** + * Gets Redis connection parameters for a specific database + * @returns Redis client config and fault injector URL + * @throws Error if required environment variables are not set or if database config is invalid + */ +export const getConfig = (): TestConfig => { + const envConfig = getEnvConfig(); + const redisConfig = getDatabaseConfigFromEnv( + envConfig.redisEndpointsConfigPath, + ); + + return { + clientConfig: getDatabaseConfig(redisConfig), + faultInjectorUrl: envConfig.faultInjectorUrl, + }; +}; + +/** + * Creates a test cluster client with the provided configuration, connects it and attaches an error handler listener + * @param clientConfig - The Redis connection configuration + * @param options - Optional cluster options + * @returns The created Redis Cluster client + */ +export const createClusterTestClient = ( + clientConfig: RedisConnectionConfig, + options: Partial = {}, +) => { + return RedisCluster.create({ + ...options, + rootNodes: [ + { + socket: { + host: clientConfig.host, + port: clientConfig.port, + }, + }, + ], + defaults: { + credentialsProvider: { + type: "async-credentials-provider", + credentials: async () => ({ + username: clientConfig.username, + password: clientConfig.password, + }), + }, + }, + }); +}; + +export type Cluster = ReturnType; + +/** + * A list of example Redis Cluster channel keys covering all slot ranges. + */ +export const CHANNELS = [ + "channel:11kv:1000", + "channel:osy:2000", + "channel:jn6:3000", + "channel:l00:4000", + "channel:4ez:5000", + "channel:4ek:6000", + "channel:9vn:7000", + "channel:dw1:8000", + "channel:9zi:9000", + "channel:4vl:10000", + "channel:utl:11000", + "channel:lyo:12000", + "channel:jzn:13000", + "channel:14uc:14000", + "channel:mz:15000", + "channel:d0v:16000", +]; + +export const CHANNELS_BY_SLOT = { + 1000: "channel:11kv:1000", + 2000: "channel:osy:2000", + 3000: "channel:jn6:3000", + 4000: "channel:l00:4000", + 5000: "channel:4ez:5000", + 6000: "channel:4ek:6000", + 7000: "channel:9vn:7000", + 8000: "channel:dw1:8000", + 9000: "channel:9zi:9000", + 10000: "channel:4vl:10000", + 11000: "channel:utl:11000", + 12000: "channel:lyo:12000", + 13000: "channel:jzn:13000", + 14000: "channel:14uc:14000", + 15000: "channel:mz:15000", + 16000: "channel:d0v:16000", +} as const;