1
0
mirror of https://github.com/redis/node-redis.git synced 2025-12-11 09:22:35 +03:00

chore: proxy improvements (#3121)

* introduce global interceptors

* move proxy stuff to new folder

* implement resp framer

* properly handle request/response and push

* add global interceptor
This commit is contained in:
Nikolay Karadzhov
2025-11-03 11:08:17 +02:00
committed by GitHub
parent 96a8a847f6
commit 130e88d45c
7 changed files with 1347 additions and 201 deletions

View File

@@ -26,7 +26,7 @@ import { hideBin } from 'yargs/helpers';
import * as fs from 'node:fs';
import * as os from 'node:os';
import * as path from 'node:path';
import { RedisProxy, getFreePortNumber } from './redis-proxy';
import { RedisProxy, getFreePortNumber } from './proxy/redis-proxy';
interface TestUtilsConfig {
/**

View File

@@ -0,0 +1,315 @@
import { strict as assert } from 'node:assert';
import { Buffer } from 'node:buffer';
import { testUtils, GLOBAL } from '../test-utils';
import { InterceptorDescription, RedisProxy } from './redis-proxy';
import type { RedisClientType } from '@redis/client/lib/client/index.js';
describe('RedisSocketProxy', function () {
testUtils.testWithClient('basic proxy functionality', async (client: RedisClientType<any, any, any, any, any>) => {
const socketOptions = client?.options?.socket;
//@ts-ignore
assert(socketOptions?.port, 'Test requires a TCP connection to Redis');
const proxyPort = 50000 + Math.floor(Math.random() * 10000);
const proxy = new RedisProxy({
listenHost: '127.0.0.1',
listenPort: proxyPort,
//@ts-ignore
targetPort: socketOptions.port,
//@ts-ignore
targetHost: socketOptions.host || '127.0.0.1',
enableLogging: true
});
const proxyEvents = {
connections: [] as any[],
dataTransfers: [] as any[]
};
proxy.on('connection', (connectionInfo) => {
proxyEvents.connections.push(connectionInfo);
});
proxy.on('data', (connectionId, direction, data) => {
proxyEvents.dataTransfers.push({ connectionId, direction, dataLength: data.length });
});
try {
await proxy.start();
const proxyClient = client.duplicate({
socket: {
port: proxyPort,
host: '127.0.0.1'
},
});
await proxyClient.connect();
const stats = proxy.getStats();
assert.equal(stats.activeConnections, 1, 'Should have one active connection');
assert.equal(proxyEvents.connections.length, 1, 'Should have recorded one connection event');
const pingResult = await proxyClient.ping();
assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy');
const clientToServerTransfers = proxyEvents.dataTransfers.filter(t => t.direction === 'client->server');
const serverToClientTransfers = proxyEvents.dataTransfers.filter(t => t.direction === 'server->client');
assert(clientToServerTransfers.length > 0, 'Should have client->server data transfers');
assert(serverToClientTransfers.length > 0, 'Should have server->client data transfers');
const testKey = `test:proxy:${Date.now()}`;
const testValue = 'proxy-test-value';
await proxyClient.set(testKey, testValue);
const retrievedValue = await proxyClient.get(testKey);
assert.equal(retrievedValue, testValue, 'Should be able to set and get values through proxy');
proxyClient.destroy();
} finally {
await proxy.stop();
}
}, GLOBAL.SERVERS.OPEN_RESP_3);
testUtils.testWithProxiedClient('custom message injection via proxy client',
async (proxiedClient: RedisClientType<any, any, any, any, any>, proxy: RedisProxy) => {
const customMessageTransfers: any[] = [];
proxy.on('data', (connectionId, direction, data) => {
if (direction === 'server->client') {
customMessageTransfers.push({ connectionId, dataLength: data.length, data });
}
});
const stats = proxy.getStats();
assert.equal(stats.activeConnections, 1, 'Should have one active connection');
// Send a resp3 push
const customMessage = Buffer.from('>4\r\n$6\r\nMOVING\r\n:1\r\n:2\r\n$6\r\nhost:3\r\n');
const sendResults = proxy.sendToAllClients(customMessage);
assert.equal(sendResults.length, 1, 'Should send to one client');
assert.equal(sendResults[0].success, true, 'Custom message send should succeed');
const customMessageFound = customMessageTransfers.find(transfer =>
transfer.dataLength === customMessage.length
);
assert(customMessageFound, 'Should have recorded the custom message transfer');
assert.equal(customMessageFound.dataLength, customMessage.length,
'Custom message length should match');
const pingResult = await proxiedClient.ping();
assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy');
}, GLOBAL.SERVERS.OPEN_RESP_3);
describe("Middleware", () => {
testUtils.testWithProxiedClient(
"Modify request/response via middleware",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
// Intercept PING commands and modify the response
const pingInterceptor: InterceptorDescription = {
name: `ping`,
fn: async (data, next) => {
if (data.includes('PING')) {
return Buffer.from("+PINGINTERCEPTED\r\n");
}
return next(data);
}
};
// Only intercept GET responses and double numeric values
// Does not modify other commands or non-numeric GET responses
const doubleNumberGetInterceptor: InterceptorDescription = {
name: `double-number-get`,
fn: async (data, next) => {
const response = await next(data);
// Not a GET command, return original response
if (!data.includes("GET")) return response;
const value = (response.toString().split("\r\n"))[1];
const number = Number(value);
// Not a number, return original response
if(isNaN(number)) return response;
const doubled = String(number * 2);
return Buffer.from(`$${doubled.length}\r\n${doubled}\r\n`);
}
};
proxy.setGlobalInterceptors([ pingInterceptor, doubleNumberGetInterceptor ])
const pingResponse = await proxiedClient.ping();
assert.equal(pingResponse, 'PINGINTERCEPTED', 'Response should be modified by middleware');
await proxiedClient.set('foo', 1);
const getResponse1 = await proxiedClient.get('foo');
assert.equal(getResponse1, '2', 'GET response should be doubled for numbers by middleware');
await proxiedClient.set('bar', 'Hi');
const getResponse2 = await proxiedClient.get('bar');
assert.equal(getResponse2, 'Hi', 'GET response should not be modified for strings by middleware');
await proxiedClient.hSet('baz', 'foo', 'dictvalue');
const hgetResponse = await proxiedClient.hGet('baz', 'foo');
assert.equal(hgetResponse, 'dictvalue', 'HGET response should not be modified by middleware');
},
GLOBAL.SERVERS.OPEN_RESP_3,
);
testUtils.testWithProxiedClient(
"Stats reflect middleware activity",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
const PING = `ping`;
const SKIPPED = `skipped`;
proxy.setGlobalInterceptors([
{
name: PING,
matchLimit: 3,
fn: async (data, next, state) => {
state.invokeCount++;
if(state.matchCount === state.matchLimit) return next(data);
if (data.includes("PING")) {
state.matchCount++;
return Buffer.from("+PINGINTERCEPTED\r\n");
}
return next(data);
},
},
{
name: SKIPPED,
fn: async (data, next, state) => {
state.invokeCount++;
state.matchCount++;
// This interceptor does not match anything
return next(data);
},
},
]);
await proxiedClient.ping();
await proxiedClient.ping();
await proxiedClient.ping();
let stats = proxy.getStats();
let pingInterceptor = stats.globalInterceptors.find(
(i) => i.name === PING,
);
assert.ok(pingInterceptor, "PING interceptor stats should be present");
assert.equal(pingInterceptor.invokeCount, 3);
assert.equal(pingInterceptor.matchCount, 3);
let skipInterceptor = stats.globalInterceptors.find(
(i) => i.name === SKIPPED,
);
assert.ok(skipInterceptor, "SKIPPED interceptor stats should be present");
assert.equal(skipInterceptor.invokeCount, 0);
assert.equal(skipInterceptor.matchCount, 0);
await proxiedClient.set("foo", "bar");
await proxiedClient.get("foo");
stats = proxy.getStats();
pingInterceptor = stats.globalInterceptors.find(
(i) => i.name === PING,
);
assert.ok(pingInterceptor, "PING interceptor stats should be present");
assert.equal(pingInterceptor.invokeCount, 5);
assert.equal(pingInterceptor.matchCount, 3);
await proxiedClient.ping();
stats = proxy.getStats();
pingInterceptor = stats.globalInterceptors.find(
(i) => i.name === PING,
);
assert.ok(pingInterceptor, "PING interceptor stats should be present");
assert.equal(pingInterceptor.invokeCount, 6);
assert.equal(pingInterceptor.matchCount, 3, 'Should not match more than limit');
skipInterceptor = stats.globalInterceptors.find(
(i) => i.name === SKIPPED,
);
assert.ok(skipInterceptor, "PING interceptor stats should be present");
assert.equal(skipInterceptor.invokeCount, 3);
assert.equal(skipInterceptor.matchCount, 3);
},
GLOBAL.SERVERS.OPEN_RESP_3,
);
testUtils.testWithProxiedClient(
"Middleware is given exactly one RESP message at a time",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
proxy.setGlobalInterceptors([
{
name: `ping`,
fn: async (data, next, state) => {
state.invokeCount++;
if (data.equals(Buffer.from("*1\r\n$4\r\nPING\r\n"))) {
state.matchCount++;
}
return next(data);
},
},
]);
await Promise.all([proxiedClient.ping(), proxiedClient.ping()]);
const stats = proxy.getStats();
const pingInterceptor = stats.globalInterceptors.find(
(i) => i.name === `ping`,
);
assert.ok(pingInterceptor, "PING interceptor stats should be present");
assert.equal(pingInterceptor.invokeCount, 2);
assert.equal(pingInterceptor.matchCount, 2);
},
GLOBAL.SERVERS.OPEN_RESP_3,
);
testUtils.testWithProxiedClient(
"Proxy passes through push messages",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
let resolve: (value: string) => void;
const promise = new Promise((rs) => { resolve = rs; });
await proxiedClient.subscribe("test-push-channel", (message) => {
resolve(message);
});
await proxiedClient.publish("test-push-channel", "hello");
const result = await promise;
assert.equal(result, "hello", "Should receive push message through proxy");
},
{
...GLOBAL.SERVERS.OPEN_RESP_3,
clientOptions: {
maintNotifications: 'disabled',
disableClientInfo: true,
RESP: 3
}
},
);
});
});

View File

@@ -1,5 +1,7 @@
import * as net from 'net';
import { EventEmitter } from 'events';
import RespFramer from './resp-framer';
import RespQueue from './resp-queue';
interface ProxyConfig {
readonly listenPort: number;
@@ -10,17 +12,21 @@ interface ProxyConfig {
readonly enableLogging?: boolean;
}
interface ConnectionInfo {
interface ConnectionInfoCommon {
readonly id: string;
readonly clientAddress: string;
readonly clientPort: number;
readonly connectedAt: Date;
}
interface ActiveConnection extends ConnectionInfo {
interface ConnectionInfo extends ConnectionInfoCommon {
readonly interceptors: InterceptorState[];
}
interface ActiveConnection extends ConnectionInfoCommon {
readonly clientSocket: net.Socket;
readonly serverSocket: net.Socket;
inflightRequestsCount: number
interceptors: Interceptor[];
}
type SendResult =
@@ -33,6 +39,7 @@ interface ProxyStats {
readonly activeConnections: number;
readonly totalConnections: number;
readonly connections: readonly ConnectionInfo[];
readonly globalInterceptors: InterceptorState[];
}
interface ProxyEvents {
@@ -50,16 +57,35 @@ interface ProxyEvents {
'close': () => void;
}
export type Interceptor = (data: Buffer) => Promise<Buffer>;
export type InterceptorFunction = (data: Buffer, next: Interceptor) => Promise<Buffer>;
type InterceptorInitializer = (init: Interceptor) => Interceptor;
export type Next = (data: Buffer) => Promise<Buffer>;
export type InterceptorFunction = (data: Buffer, next: Next, state: InterceptorState) => Promise<Buffer>;
export interface InterceptorDescription {
name: string;
matchLimit?: number;
fn: InterceptorFunction;
}
export interface InterceptorState {
name: string;
matchLimit?: number;
invokeCount: number;
matchCount: number;
}
interface Interceptor {
name: string;
state: InterceptorState;
fn: InterceptorFunction;
}
export class RedisProxy extends EventEmitter {
private readonly server: net.Server;
public readonly config: Required<ProxyConfig>;
private readonly connections: Map<string, ActiveConnection>;
private isRunning: boolean;
private interceptorInitializer: InterceptorInitializer = (init) => init;
private globalInterceptors: Interceptor[] = [];
constructor(config: ProxyConfig) {
super();
@@ -119,11 +145,32 @@ export class RedisProxy extends EventEmitter {
});
}
public setInterceptors(interceptors: Array<InterceptorFunction>) {
this.interceptorInitializer = (init) => interceptors.reduceRight<Interceptor>(
(next, mw) => (data) => mw(data, next),
init
);
private makeInterceptor(description: InterceptorDescription): Interceptor {
const { name, fn, matchLimit } = description;
return {
name,
fn,
state: {
name,
matchCount: 0,
invokeCount: 0,
matchLimit,
},
};
}
public setGlobalInterceptors(
interceptorDescriptions: Array<InterceptorDescription>,
) {
const interceptors: Interceptor[] = interceptorDescriptions.map(this.makeInterceptor);
this.globalInterceptors = interceptors;
}
public addGlobalInterceptor(
interceptorDescription: InterceptorDescription,
) {
const interceptor = this.makeInterceptor(interceptorDescription);
this.globalInterceptors = [interceptor, ...this.globalInterceptors.filter(i => i.name !== interceptor.name)];
}
public getStats(): ProxyStats {
@@ -132,12 +179,14 @@ export class RedisProxy extends EventEmitter {
return {
activeConnections: connections.length,
totalConnections: connections.length,
globalInterceptors: this.globalInterceptors.map(i => i.state),
connections: connections.map((conn) => ({
id: conn.id,
clientAddress: conn.clientAddress,
clientPort: conn.clientPort,
connectedAt: conn.connectedAt,
}))
interceptors: conn.interceptors.map(i => i.state)
})),
};
}
@@ -246,7 +295,7 @@ export class RedisProxy extends EventEmitter {
connectedAt: new Date(),
clientSocket,
serverSocket,
inflightRequestsCount: 0
interceptors: [],
};
this.connections.set(connectionId, connectionInfo);
@@ -259,33 +308,39 @@ export class RedisProxy extends EventEmitter {
this.emit('connection', connectionInfo);
});
clientSocket.on('data', async (data) => {
this.emit('data', connectionId, 'client->server', data);
/**
*
* client -> clientSocket -> clientRespFramer -> interceptors -> queue -> serverSocket -> server
* client <- clientSocket <- interceptors <- response | queue <- serverRespFramer <- serverSocket <- server
* client <- clientSocket <- push |
*/
const clientRespFramer = new RespFramer();
const respQueue = new RespQueue(serverSocket);
connectionInfo.inflightRequestsCount++;
clientRespFramer.on('message', async (data) => {
// next1 -> next2 -> ... -> last -> server
// next1 <- next2 <- ... <- last <- server
const last = (data: Buffer): Promise<Buffer> => {
return new Promise((resolve, reject) => {
serverSocket.write(data);
serverSocket.once('data', (data) => {
connectionInfo.inflightRequestsCount--;
assert(connectionInfo.inflightRequestsCount >= 0, `inflightRequestsCount for connection ${connectionId} went below zero`);
this.emit('data', connectionId, 'server->client', data);
resolve(data);
});
serverSocket.once('error', reject);
});
const last = async (data: Buffer): Promise<Buffer> => {
this.emit('data', connectionId, 'client->server', data);
const response = await respQueue.request(data);
return response;
};
const interceptorChain = this.interceptorInitializer(last);
const interceptorChain = connectionInfo.interceptors.concat(this.globalInterceptors).reduceRight<Next>(
(next, interceptor) => (data) =>
interceptor.fn(data, next, interceptor.state),
last,
);
const response = await interceptorChain(data);
this.emit('data', connectionId, 'server->client', response);
clientSocket.write(response);
});
serverSocket.on('data', (data) => {
if (connectionInfo.inflightRequestsCount > 0) return;
clientSocket.on('data', data => clientRespFramer.write(data));
respQueue.on('push', (data) => {
this.emit('data', connectionId, 'server->client', data);
clientSocket.write(data);
});
@@ -310,7 +365,6 @@ export class RedisProxy extends EventEmitter {
});
serverSocket.on('error', (error) => {
if (connectionInfo.inflightRequestsCount > 0) return;
this.log(`Server error for connection ${connectionId}: ${error.message}`);
this.emit('error', error, connectionId);
clientSocket.destroy();
@@ -344,7 +398,6 @@ export class RedisProxy extends EventEmitter {
}
}
import { createServer } from 'net';
import assert from 'node:assert';
export function getFreePortNumber(): Promise<number> {
return new Promise((resolve, reject) => {

View File

@@ -0,0 +1,735 @@
import { strict as assert } from 'node:assert';
import RespFramer from './resp-framer';
describe('RespFramer - RESP2', () => {
it('should emit a simple string message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('+OK\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit an error message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('-ERR unknown command\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit an integer message', async () => {
const framer = new RespFramer();
const expected = Buffer.from(':1000\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a bulk string message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('$6\r\nfoobar\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a null bulk string', async () => {
const framer = new RespFramer();
const expected = Buffer.from('$-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit an array message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a null array', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit nested arrays', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle multiple complete messages', async () => {
const framer = new RespFramer();
const messages = [
Buffer.from('+OK\r\n'),
Buffer.from(':42\r\n'),
Buffer.from('$3\r\nfoo\r\n')
];
const combined = Buffer.concat(messages);
const received: Buffer[] = [];
const messagesPromise = new Promise<Buffer[]>((resolve) => {
framer.on('message', (message) => {
received.push(message);
if (received.length === 3) {
resolve(received);
}
});
});
framer.write(combined);
const result = await messagesPromise;
assert.equal(result.length, messages.length);
messages.forEach((expected, i) => {
assert.deepEqual(result[i], expected);
});
});
it('should handle partial messages across multiple writes', async () => {
const framer = new RespFramer();
const fullMessage = Buffer.from('$6\r\nfoobar\r\n');
const part1 = fullMessage.subarray(0, 5);
const part2 = fullMessage.subarray(5);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(part1);
framer.write(part2);
const message = await messagePromise;
assert.deepEqual(message, fullMessage);
});
it('should handle array split across multiple writes', async () => {
const framer = new RespFramer();
const fullMessage = Buffer.from('*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n');
const part1 = fullMessage.subarray(0, 10);
const part2 = fullMessage.subarray(10);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(part1);
framer.write(part2);
const message = await messagePromise;
assert.deepEqual(message, fullMessage);
});
it('should handle empty array', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*0\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle empty bulk string', async () => {
const framer = new RespFramer();
const expected = Buffer.from('$0\r\n\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle mixed message types in sequence', async () => {
const framer = new RespFramer();
const messages = [
Buffer.from('+PONG\r\n'),
Buffer.from('$3\r\nGET\r\n'),
Buffer.from('*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n'),
Buffer.from(':123\r\n'),
Buffer.from('-Error\r\n')
];
const received: Buffer[] = [];
const messagesPromise = new Promise<Buffer[]>((resolve) => {
framer.on('message', (message) => {
received.push(message);
if (received.length === messages.length) {
resolve(received);
}
});
});
messages.forEach(msg => framer.write(msg));
const result = await messagesPromise;
assert.equal(result.length, messages.length);
messages.forEach((expected, i) => {
assert.deepEqual(result[i], expected);
});
});
it('should handle bulk string containing \\r\\n in the data', async () => {
const framer = new RespFramer();
const expected = Buffer.from('$12\r\nhello\r\nworld\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle bulk string with binary data including null bytes', async () => {
const framer = new RespFramer();
const binaryData = Buffer.from([0x00, 0x01, 0x02, 0xff, 0xfe]);
const expected = Buffer.concat([
Buffer.from('$5\r\n'),
binaryData,
Buffer.from('\r\n')
]);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle array with bulk strings containing \\r\\n', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*2\r\n$5\r\nfoo\r\n\r\n$5\r\nbar\r\n\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
});
describe('RespFramer - RESP3', () => {
it('should emit a null message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('_\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a boolean true message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('#t\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a boolean false message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('#f\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a double message', async () => {
const framer = new RespFramer();
const expected = Buffer.from(',3.14159\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a double infinity message', async () => {
const framer = new RespFramer();
const expected = Buffer.from(',inf\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a double negative infinity message', async () => {
const framer = new RespFramer();
const expected = Buffer.from(',-inf\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a big number message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('(3492890328409238509324850943850943825024385\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a bulk error message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('!21\r\nSYNTAX invalid syntax\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a verbatim string message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('=15\r\ntxt:Some string\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a map message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a set message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('~3\r\n+apple\r\n+banana\r\n+cherry\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit a push message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('>3\r\n+pubsub\r\n+message\r\n+channel\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should emit an attribute message', async () => {
const framer = new RespFramer();
const expected = Buffer.from('|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.1923\r\n$1\r\nb\r\n,0.0012\r\n*2\r\n:2039123\r\n:9543892\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle nested RESP3 structures', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%2\r\n$4\r\nname\r\n$5\r\nAlice\r\n$3\r\nage\r\n:30\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle empty map', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%0\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle empty set', async () => {
const framer = new RespFramer();
const expected = Buffer.from('~0\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle map with nested arrays', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%1\r\n$4\r\ndata\r\n*2\r\n:1\r\n:2\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle set with mixed types', async () => {
const framer = new RespFramer();
const expected = Buffer.from('~4\r\n+string\r\n:42\r\n#t\r\n_\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle RESP3 split across multiple writes', async () => {
const framer = new RespFramer();
const fullMessage = Buffer.from('%2\r\n+key1\r\n:100\r\n+key2\r\n:200\r\n');
const part1 = fullMessage.subarray(0, 10);
const part2 = fullMessage.subarray(10);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(part1);
framer.write(part2);
const message = await messagePromise;
assert.deepEqual(message, fullMessage);
});
it('should handle mixed RESP2 and RESP3 messages', async () => {
const framer = new RespFramer();
const messages = [
Buffer.from('*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n'),
Buffer.from('%1\r\n+result\r\n$5\r\nvalue\r\n'),
Buffer.from('#t\r\n'),
Buffer.from('_\r\n'),
Buffer.from(',3.14\r\n')
];
const received: Buffer[] = [];
const messagesPromise = new Promise<Buffer[]>((resolve) => {
framer.on('message', (message) => {
received.push(message);
if (received.length === messages.length) {
resolve(received);
}
});
});
messages.forEach(msg => framer.write(msg));
const result = await messagesPromise;
assert.equal(result.length, messages.length);
messages.forEach((expected, i) => {
assert.deepEqual(result[i], expected);
});
});
it('should handle array with attribute metadata', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle null map', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle null set', async () => {
const framer = new RespFramer();
const expected = Buffer.from('~-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle null push', async () => {
const framer = new RespFramer();
const expected = Buffer.from('>-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle attribute with empty metadata', async () => {
const framer = new RespFramer();
const expected = Buffer.from('|0\r\n:42\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle blob error with binary data', async () => {
const framer = new RespFramer();
const binaryData = Buffer.from([0x00, 0x01, 0x02, 0xff, 0xfe]);
const expected = Buffer.concat([
Buffer.from('!5\r\n'),
binaryData,
Buffer.from('\r\n')
]);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle verbatim string with different encoding', async () => {
const framer = new RespFramer();
const expected = Buffer.from('=17\r\nmkd:# Hello World\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle double NaN', async () => {
const framer = new RespFramer();
const expected = Buffer.from(',nan\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle deeply nested structures', async () => {
const framer = new RespFramer();
const expected = Buffer.from('*2\r\n%1\r\n+key\r\n*2\r\n:1\r\n:2\r\n~2\r\n+a\r\n+b\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle push with nested map', async () => {
const framer = new RespFramer();
const expected = Buffer.from('>2\r\n+pubsub\r\n%1\r\n+channel\r\n+news\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle attribute split across multiple writes', async () => {
const framer = new RespFramer();
const fullMessage = Buffer.from('|1\r\n+ttl\r\n:3600\r\n+value\r\n');
const part1 = fullMessage.subarray(0, 10);
const part2 = fullMessage.subarray(10);
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(part1);
framer.write(part2);
const message = await messagePromise;
assert.deepEqual(message, fullMessage);
});
it('should handle map with null values', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%2\r\n+key1\r\n_\r\n+key2\r\n$-1\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle nested maps', async () => {
const framer = new RespFramer();
const expected = Buffer.from('%1\r\n+outer\r\n%2\r\n+inner1\r\n:1\r\n+inner2\r\n:2\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
it('should handle set containing arrays', async () => {
const framer = new RespFramer();
const expected = Buffer.from('~2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n');
const messagePromise = new Promise<Buffer>((resolve) => {
framer.once('message', resolve);
});
framer.write(expected);
const message = await messagePromise;
assert.deepEqual(message, expected);
});
});

View File

@@ -0,0 +1,167 @@
// RespFramer: Frames raw Buffer data into complete RESP messages
// Accumulates incoming bytes and emits each complete RESP message as a separate Buffer
import EventEmitter from "node:events";
export interface RespFramerEvents {
message: (data: Buffer) => void;
push: (data: Buffer) => void;
}
export default class RespFramer extends EventEmitter {
private buffer: Buffer;
private offset: number;
constructor() {
super();
this.buffer = Buffer.alloc(0);
this.offset = 0;
}
public write(data: Buffer) {
this.buffer = Buffer.concat([this.buffer, data]);
while (this.offset < this.buffer.length) {
const messageEnd = this.findMessageEnd(this.buffer, this.offset);
if (messageEnd === -1) {
break; // Incomplete message
}
const message = this.buffer.subarray(this.offset, messageEnd);
this.emit("message", message);
this.offset = messageEnd;
}
// Remove processed data from the buffer
if (this.offset > 0) {
this.buffer = this.buffer.subarray(this.offset);
this.offset = 0;
}
}
private findMessageEnd(buffer: Buffer, start: number): number {
if (start >= buffer.length) {
return -1;
}
const prefix = String.fromCharCode(buffer[start]);
switch (prefix) {
case "+": // Simple String
case "-": // Error
case ":": // Integer
case "_": // Null
case "#": // Boolean
case ",": // Double
case "(": // Big Number
return this.findLineEnd(buffer, start);
case "$": // Bulk String
case "!": // Bulk Error
case "=": // Verbatim String
return this.findBulkStringEnd(buffer, start);
case "*": // Array
return this.findArrayEnd(buffer, start);
case "%": // Map
return this.findMapEnd(buffer, start);
case "~": // Set
case ">": // Push
return this.findArrayEnd(buffer, start);
case "|": // Attribute
return this.findAttributeEnd(buffer, start);
default:
return -1; // Unknown prefix
}
}
private findArrayEnd(buffer: Buffer, start: number): number {
const result = this.readLength(buffer, start);
if (!result) {
return -1;
}
const { length, lineEnd } = result;
if (length === -1) {
return lineEnd;
}
let currentOffset = lineEnd;
for (let i = 0; i < length; i++) {
const elementEnd = this.findMessageEnd(buffer, currentOffset);
if (elementEnd === -1) {
return -1;
}
currentOffset = elementEnd;
}
return currentOffset;
}
private findBulkStringEnd(buffer: Buffer, start: number): number {
const result = this.readLength(buffer, start);
if (!result) {
return -1;
}
const { length, lineEnd } = result;
if (length === -1) {
return lineEnd;
}
const totalLength = lineEnd + length + 2;
return totalLength <= buffer.length ? totalLength : -1;
}
private findMapEnd(buffer: Buffer, start: number): number {
const result = this.readLength(buffer, start);
if (!result) {
return -1;
}
const { length, lineEnd } = result;
if (length === -1) {
return lineEnd;
}
let currentOffset = lineEnd;
for (let i = 0; i < length * 2; i++) {
const elementEnd = this.findMessageEnd(buffer, currentOffset);
if (elementEnd === -1) {
return -1;
}
currentOffset = elementEnd;
}
return currentOffset;
}
private findAttributeEnd(buffer: Buffer, start: number): number {
const result = this.readLength(buffer, start);
if (!result) {
return -1;
}
const { length, lineEnd } = result;
let currentOffset = lineEnd;
for (let i = 0; i < length * 2; i++) {
const elementEnd = this.findMessageEnd(buffer, currentOffset);
if (elementEnd === -1) {
return -1;
}
currentOffset = elementEnd;
}
const valueEnd = this.findMessageEnd(buffer, currentOffset);
if (valueEnd === -1) {
return -1;
}
return valueEnd;
}
private findLineEnd(buffer: Buffer, start: number): number {
const end = buffer.indexOf("\r\n", start);
return end !== -1 ? end + 2 : -1;
}
private readLength(
buffer: Buffer,
start: number,
): { length: number; lineEnd: number } | null {
const lineEnd = this.findLineEnd(buffer, start);
if (lineEnd === -1) {
return null;
}
const lengthLine = buffer.subarray(start + 1, lineEnd - 2).toString();
const length = parseInt(lengthLine, 10);
if (isNaN(length)) {
return null;
}
return { length, lineEnd };
}
}

View File

@@ -0,0 +1,43 @@
import { EventEmitter } from "node:events";
import RespFramer from "./resp-framer";
import { Socket } from "node:net";
interface Request {
resolve: (data: Buffer) => void;
reject: (reason: any) => void;
}
export default class RespQueue extends EventEmitter {
queue: Request[] = [];
respFramer: RespFramer = new RespFramer();
constructor(private serverSocket: Socket) {
super();
this.respFramer.on("message", (msg) => this.handleMessage(msg));
this.serverSocket.on("data", (data) => this.respFramer.write(data));
}
handleMessage(data: Buffer) {
const request = this.queue.shift();
if (request) {
request.resolve(data);
} else {
this.emit("push", data);
}
}
request(data: Buffer): Promise<Buffer> {
let resolve: (data: Buffer) => void;
let reject: (reason: any) => void;
const promise = new Promise<Buffer>((rs, rj) => {
resolve = rs;
reject = rj;
});
//@ts-ignore
this.queue.push({ resolve, reject });
this.serverSocket.write(data);
return promise;
}
}

View File

@@ -1,167 +0,0 @@
import { strict as assert } from 'node:assert';
import { Buffer } from 'node:buffer';
import { testUtils, GLOBAL } from './test-utils';
import { InterceptorFunction, RedisProxy } from './redis-proxy';
import type { RedisClientType } from '@redis/client/lib/client/index.js';
describe('RedisSocketProxy', function () {
testUtils.testWithClient('basic proxy functionality', async (client: RedisClientType<any, any, any, any, any>) => {
const socketOptions = client?.options?.socket;
//@ts-ignore
assert(socketOptions?.port, 'Test requires a TCP connection to Redis');
const proxyPort = 50000 + Math.floor(Math.random() * 10000);
const proxy = new RedisProxy({
listenHost: '127.0.0.1',
listenPort: proxyPort,
//@ts-ignore
targetPort: socketOptions.port,
//@ts-ignore
targetHost: socketOptions.host || '127.0.0.1',
enableLogging: true
});
const proxyEvents = {
connections: [] as any[],
dataTransfers: [] as any[]
};
proxy.on('connection', (connectionInfo) => {
proxyEvents.connections.push(connectionInfo);
});
proxy.on('data', (connectionId, direction, data) => {
proxyEvents.dataTransfers.push({ connectionId, direction, dataLength: data.length });
});
try {
await proxy.start();
const proxyClient = client.duplicate({
socket: {
port: proxyPort,
host: '127.0.0.1'
},
});
await proxyClient.connect();
const stats = proxy.getStats();
assert.equal(stats.activeConnections, 1, 'Should have one active connection');
assert.equal(proxyEvents.connections.length, 1, 'Should have recorded one connection event');
const pingResult = await proxyClient.ping();
assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy');
const clientToServerTransfers = proxyEvents.dataTransfers.filter(t => t.direction === 'client->server');
const serverToClientTransfers = proxyEvents.dataTransfers.filter(t => t.direction === 'server->client');
assert(clientToServerTransfers.length > 0, 'Should have client->server data transfers');
assert(serverToClientTransfers.length > 0, 'Should have server->client data transfers');
const testKey = `test:proxy:${Date.now()}`;
const testValue = 'proxy-test-value';
await proxyClient.set(testKey, testValue);
const retrievedValue = await proxyClient.get(testKey);
assert.equal(retrievedValue, testValue, 'Should be able to set and get values through proxy');
proxyClient.destroy();
} finally {
await proxy.stop();
}
}, GLOBAL.SERVERS.OPEN_RESP_3);
testUtils.testWithProxiedClient('custom message injection via proxy client',
async (proxiedClient: RedisClientType<any, any, any, any, any>, proxy: RedisProxy) => {
const customMessageTransfers: any[] = [];
proxy.on('data', (connectionId, direction, data) => {
if (direction === 'server->client') {
customMessageTransfers.push({ connectionId, dataLength: data.length, data });
}
});
const stats = proxy.getStats();
assert.equal(stats.activeConnections, 1, 'Should have one active connection');
// Send a resp3 push
const customMessage = Buffer.from('>4\r\n$6\r\nMOVING\r\n:1\r\n:2\r\n$6\r\nhost:3\r\n');
const sendResults = proxy.sendToAllClients(customMessage);
assert.equal(sendResults.length, 1, 'Should send to one client');
assert.equal(sendResults[0].success, true, 'Custom message send should succeed');
const customMessageFound = customMessageTransfers.find(transfer =>
transfer.dataLength === customMessage.length
);
assert(customMessageFound, 'Should have recorded the custom message transfer');
assert.equal(customMessageFound.dataLength, customMessage.length,
'Custom message length should match');
const pingResult = await proxiedClient.ping();
assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy');
}, GLOBAL.SERVERS.OPEN_RESP_3);
describe("Middleware", () => {
testUtils.testWithProxiedClient(
"Modify request/response via middleware",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
// Intercept PING commands and modify the response
const pingInterceptor: InterceptorFunction = async (data, next) => {
if (data.includes('PING')) {
return Buffer.from("+PINGINTERCEPTED\r\n");
}
return next(data);
};
// Only intercept GET responses and double numeric values
// Does not modify other commands or non-numeric GET responses
const doubleNumberGetInterceptor: InterceptorFunction = async (data, next) => {
const response = await next(data);
// Not a GET command, return original response
if (!data.includes("GET")) return response;
const value = (response.toString().split("\r\n"))[1];
const number = Number(value);
// Not a number, return original response
if(isNaN(number)) return response;
const doubled = String(number * 2);
return Buffer.from(`$${doubled.length}\r\n${doubled}\r\n`);
};
proxy.setInterceptors([ pingInterceptor, doubleNumberGetInterceptor ])
const pingResponse = await proxiedClient.ping();
assert.equal(pingResponse, 'PINGINTERCEPTED', 'Response should be modified by middleware');
await proxiedClient.set('foo', 1);
const getResponse1 = await proxiedClient.get('foo');
assert.equal(getResponse1, '2', 'GET response should be doubled for numbers by middleware');
await proxiedClient.set('bar', 'Hi');
const getResponse2 = await proxiedClient.get('bar');
assert.equal(getResponse2, 'Hi', 'GET response should not be modified for strings by middleware');
await proxiedClient.hSet('baz', 'foo', 'dictvalue');
const hgetResponse = await proxiedClient.hGet('baz', 'foo');
assert.equal(hgetResponse, 'dictvalue', 'HGET response should not be modified by middleware');
},
GLOBAL.SERVERS.OPEN_RESP_3,
);
});
});