1
0
mirror of https://github.com/redis/node-redis.git synced 2025-12-12 21:21:15 +03:00

feat(proxy): implement express style middleware (#3105)

This commit is contained in:
Nikolay Karadzhov
2025-10-21 14:47:37 +03:00
committed by GitHub
parent 1cda848393
commit b8267c9b82
3 changed files with 102 additions and 9 deletions

View File

@@ -313,11 +313,10 @@ export default class TestUtils {
//@ts-ignore //@ts-ignore
targetPort: socketOptions.port, targetPort: socketOptions.port,
//@ts-ignore //@ts-ignore
targetHost: socketOptions.host, targetHost: socketOptions.host ?? '127.0.0.1',
enableLogging: true enableLogging: true
}); });
await proxy.start(); await proxy.start();
const proxyClient = client.duplicate({ const proxyClient = client.duplicate({
socket: { socket: {

View File

@@ -1,7 +1,7 @@
import { strict as assert } from 'node:assert'; import { strict as assert } from 'node:assert';
import { Buffer } from 'node:buffer'; import { Buffer } from 'node:buffer';
import { testUtils, GLOBAL } from './test-utils'; import { testUtils, GLOBAL } from './test-utils';
import { RedisProxy } from './redis-proxy'; import { InterceptorFunction, RedisProxy } from './redis-proxy';
import type { RedisClientType } from '@redis/client/lib/client/index.js'; import type { RedisClientType } from '@redis/client/lib/client/index.js';
describe('RedisSocketProxy', function () { describe('RedisSocketProxy', function () {
@@ -107,5 +107,61 @@ describe('RedisSocketProxy', function () {
const pingResult = await proxiedClient.ping(); const pingResult = await proxiedClient.ping();
assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy'); assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy');
}, GLOBAL.SERVERS.OPEN_RESP_3) }, GLOBAL.SERVERS.OPEN_RESP_3);
describe("Middleware", () => {
testUtils.testWithProxiedClient(
"Modify request/response via middleware",
async (
proxiedClient: RedisClientType<any, any, any, any, any>,
proxy: RedisProxy,
) => {
// Intercept PING commands and modify the response
const pingInterceptor: InterceptorFunction = async (data, next) => {
if (data.includes('PING')) {
return Buffer.from("+PINGINTERCEPTED\r\n");
}
return next(data);
};
// Only intercept GET responses and double numeric values
// Does not modify other commands or non-numeric GET responses
const doubleNumberGetInterceptor: InterceptorFunction = async (data, next) => {
const response = await next(data);
// Not a GET command, return original response
if (!data.includes("GET")) return response;
const value = (response.toString().split("\r\n"))[1];
const number = Number(value);
// Not a number, return original response
if(isNaN(number)) return response;
const doubled = String(number * 2);
return Buffer.from(`$${doubled.length}\r\n${doubled}\r\n`);
};
proxy.setInterceptors([ pingInterceptor, doubleNumberGetInterceptor ])
const pingResponse = await proxiedClient.ping();
assert.equal(pingResponse, 'PINGINTERCEPTED', 'Response should be modified by middleware');
await proxiedClient.set('foo', 1);
const getResponse1 = await proxiedClient.get('foo');
assert.equal(getResponse1, '2', 'GET response should be doubled for numbers by middleware');
await proxiedClient.set('bar', 'Hi');
const getResponse2 = await proxiedClient.get('bar');
assert.equal(getResponse2, 'Hi', 'GET response should not be modified for strings by middleware');
await proxiedClient.hSet('baz', 'foo', 'dictvalue');
const hgetResponse = await proxiedClient.hGet('baz', 'foo');
assert.equal(hgetResponse, 'dictvalue', 'HGET response should not be modified by middleware');
},
GLOBAL.SERVERS.OPEN_RESP_3,
);
});
}); });

View File

@@ -20,6 +20,7 @@ interface ConnectionInfo {
interface ActiveConnection extends ConnectionInfo { interface ActiveConnection extends ConnectionInfo {
readonly clientSocket: net.Socket; readonly clientSocket: net.Socket;
readonly serverSocket: net.Socket; readonly serverSocket: net.Socket;
inflightRequestsCount: number
} }
type SendResult = type SendResult =
@@ -49,11 +50,16 @@ interface ProxyEvents {
'close': () => void; 'close': () => void;
} }
export type Interceptor = (data: Buffer) => Promise<Buffer>;
export type InterceptorFunction = (data: Buffer, next: Interceptor) => Promise<Buffer>;
type InterceptorInitializer = (init: Interceptor) => Interceptor;
export class RedisProxy extends EventEmitter { export class RedisProxy extends EventEmitter {
private readonly server: net.Server; private readonly server: net.Server;
public readonly config: Required<ProxyConfig>; public readonly config: Required<ProxyConfig>;
private readonly connections: Map<string, ActiveConnection>; private readonly connections: Map<string, ActiveConnection>;
private isRunning: boolean; private isRunning: boolean;
private interceptorInitializer: InterceptorInitializer = (init) => init;
constructor(config: ProxyConfig) { constructor(config: ProxyConfig) {
super(); super();
@@ -113,6 +119,13 @@ export class RedisProxy extends EventEmitter {
}); });
} }
public setInterceptors(interceptors: Array<InterceptorFunction>) {
this.interceptorInitializer = (init) => interceptors.reduceRight<Interceptor>(
(next, mw) => (data) => mw(data, next),
init
);
}
public getStats(): ProxyStats { public getStats(): ProxyStats {
const connections = Array.from(this.connections.values()); const connections = Array.from(this.connections.values());
@@ -218,19 +231,22 @@ export class RedisProxy extends EventEmitter {
} }
private handleClientConnection(clientSocket: net.Socket): void { private handleClientConnection(clientSocket: net.Socket): void {
const connectionId = this.generateConnectionId(); clientSocket.pause();
const serverSocket = net.createConnection({ const serverSocket = net.createConnection({
host: this.config.targetHost, host: this.config.targetHost,
port: this.config.targetPort port: this.config.targetPort
}); });
serverSocket.once('connect', clientSocket.resume.bind(clientSocket));
const connectionId = this.generateConnectionId();
const connectionInfo: ActiveConnection = { const connectionInfo: ActiveConnection = {
id: connectionId, id: connectionId,
clientAddress: clientSocket.remoteAddress || 'unknown', clientAddress: clientSocket.remoteAddress || 'unknown',
clientPort: clientSocket.remotePort || 0, clientPort: clientSocket.remotePort || 0,
connectedAt: new Date(), connectedAt: new Date(),
clientSocket, clientSocket,
serverSocket serverSocket,
inflightRequestsCount: 0
}; };
this.connections.set(connectionId, connectionInfo); this.connections.set(connectionId, connectionInfo);
@@ -243,12 +259,33 @@ export class RedisProxy extends EventEmitter {
this.emit('connection', connectionInfo); this.emit('connection', connectionInfo);
}); });
clientSocket.on('data', (data) => { clientSocket.on('data', async (data) => {
this.emit('data', connectionId, 'client->server', data); this.emit('data', connectionId, 'client->server', data);
serverSocket.write(data);
connectionInfo.inflightRequestsCount++;
// next1 -> next2 -> ... -> last -> server
// next1 <- next2 <- ... <- last <- server
const last = (data: Buffer): Promise<Buffer> => {
return new Promise((resolve, reject) => {
serverSocket.write(data);
serverSocket.once('data', (data) => {
connectionInfo.inflightRequestsCount--;
assert(connectionInfo.inflightRequestsCount >= 0, `inflightRequestsCount for connection ${connectionId} went below zero`);
this.emit('data', connectionId, 'server->client', data);
resolve(data);
});
serverSocket.once('error', reject);
});
};
const interceptorChain = this.interceptorInitializer(last);
const response = await interceptorChain(data);
clientSocket.write(response);
}); });
serverSocket.on('data', (data) => { serverSocket.on('data', (data) => {
if (connectionInfo.inflightRequestsCount > 0) return;
this.emit('data', connectionId, 'server->client', data); this.emit('data', connectionId, 'server->client', data);
clientSocket.write(data); clientSocket.write(data);
}); });
@@ -273,6 +310,7 @@ export class RedisProxy extends EventEmitter {
}); });
serverSocket.on('error', (error) => { serverSocket.on('error', (error) => {
if (connectionInfo.inflightRequestsCount > 0) return;
this.log(`Server error for connection ${connectionId}: ${error.message}`); this.log(`Server error for connection ${connectionId}: ${error.message}`);
this.emit('error', error, connectionId); this.emit('error', error, connectionId);
clientSocket.destroy(); clientSocket.destroy();
@@ -306,6 +344,7 @@ export class RedisProxy extends EventEmitter {
} }
} }
import { createServer } from 'net'; import { createServer } from 'net';
import assert from 'node:assert';
export function getFreePortNumber(): Promise<number> { export function getFreePortNumber(): Promise<number> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
@@ -326,4 +365,3 @@ export function getFreePortNumber(): Promise<number> {
export { RedisProxy as RedisTransparentProxy }; export { RedisProxy as RedisTransparentProxy };
export type { ProxyConfig, ConnectionInfo, ProxyEvents, SendResult, DataDirection, ProxyStats }; export type { ProxyConfig, ConnectionInfo, ProxyEvents, SendResult, DataDirection, ProxyStats };