1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-18 00:20:57 +03:00
This commit is contained in:
Nedyalko Dyakov
2025-07-07 18:18:37 +03:00
parent 225c0bf5b2
commit e697fcc76b
14 changed files with 2173 additions and 29 deletions

193
redis.go
View File

@ -2,6 +2,7 @@ package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@ -211,6 +212,9 @@ type baseClient struct {
// Push notification processing
pushProcessor push.NotificationProcessor
// hitlessIntegration provides hitless upgrade functionality
hitlessIntegration HitlessIntegration
}
func (c *baseClient) clone() *baseClient {
@ -466,7 +470,16 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
}
c.connPool.Put(ctx, cn)
// Check if connection is marked for closing due to hitless upgrades
if c.hitlessIntegration != nil && c.hitlessIntegration.IsConnectionMarkedForClosing(cn) {
// Connection is marked for closing (e.g., during MOVING state)
// Remove it instead of putting it back in the pool
internal.Logger.Printf(ctx, "hitless: closing connection marked for closure during upgrade")
c.connPool.Remove(ctx, cn, nil)
} else {
c.connPool.Put(ctx, cn)
}
}
}
@ -528,7 +541,33 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
}
retryTimeout := uint32(0)
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Check if this is a blocking command that needs redirection
isBlockingCommand := cmd.readTimeout() != nil
var redirectedConn *pool.Conn
var shouldRedirect bool
var newEndpoint string
if c.hitlessIntegration != nil && isBlockingCommand {
// For blocking commands, check if we need to redirect to a new endpoint
// This happens during MOVING state when the endpoint is changing
shouldRedirect, newEndpoint = c.hitlessIntegration.ShouldRedirectBlockingConnection(nil)
if shouldRedirect {
// Create a new connection to the new endpoint
var err error
redirectedConn, err = c.createConnectionToEndpoint(ctx, newEndpoint)
if err != nil {
internal.Logger.Printf(ctx, "hitless: failed to create redirected connection to %s: %v", newEndpoint, err)
// Fall back to normal connection if redirection fails
shouldRedirect = false
} else {
internal.Logger.Printf(ctx, "hitless: redirecting blocking command %s to new endpoint %s", cmd.Name(), newEndpoint)
}
}
}
// Use redirected connection if available, otherwise use normal connection
connFunc := func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the command
if err := c.processPushNotifications(ctx, cn); err != nil {
// Log the error but don't fail the command execution
@ -536,7 +575,19 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
// Mark connection as blocking if this is a blocking command
if c.hitlessIntegration != nil && isBlockingCommand {
c.hitlessIntegration.MarkConnectionAsBlocking(cn, true)
internal.Logger.Printf(ctx, "hitless: marked connection as blocking for command %s", cmd.Name())
}
// Get appropriate write timeout for this connection
writeTimeout := c.opt.WriteTimeout
if c.hitlessIntegration != nil {
_, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
}
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
}); err != nil {
atomic.StoreUint32(&retryTimeout, 1)
@ -547,7 +598,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
readReplyFunc = cmd.readRawReply
}
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
if err := cn.WithReader(c.context(ctx), c.cmdTimeoutForConnection(cmd, cn), func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
@ -564,8 +615,31 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
return err
}
// Unmark connection as blocking after command completes
if c.hitlessIntegration != nil && isBlockingCommand {
c.hitlessIntegration.MarkConnectionAsBlocking(cn, false)
internal.Logger.Printf(ctx, "hitless: unmarked connection as blocking after command %s completed", cmd.Name())
}
return nil
}); err != nil {
}
// Execute the command with either redirected or normal connection
var err error
if shouldRedirect && redirectedConn != nil {
// Use the redirected connection for blocking command
err = connFunc(ctx, redirectedConn)
// Close the redirected connection after use
defer func() {
redirectedConn.Close()
internal.Logger.Printf(ctx, "hitless: closed redirected connection to %s after blocking command completed", newEndpoint)
}()
} else {
// Use normal connection pool
err = c.withConn(ctx, connFunc)
}
if err != nil {
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return retry, err
}
@ -588,6 +662,70 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
return c.opt.ReadTimeout
}
// cmdTimeoutForConnection returns the appropriate read timeout for a specific connection
// taking into account hitless upgrade state
func (c *baseClient) cmdTimeoutForConnection(cmd Cmder, cn *pool.Conn) time.Duration {
baseTimeout := c.cmdTimeout(cmd)
// If hitless upgrades are enabled, get dynamic timeout based on connection state
if c.hitlessIntegration != nil {
// For blocking commands, use the command's timeout but check if connection needs increased timeout
if cmd.readTimeout() != nil {
// For blocking commands, use the base timeout but apply hitless upgrade adjustments
adjustedTimeout := c.hitlessIntegration.GetConnectionTimeout(cn, baseTimeout)
return adjustedTimeout
} else {
// For regular commands, get both read and write timeouts (use read timeout for command)
readTimeout, _ := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
return readTimeout
}
}
return baseTimeout
}
// createConnectionToEndpoint creates a new connection to a specific endpoint
// This is used for redirecting blocking commands during MOVING state
func (c *baseClient) createConnectionToEndpoint(ctx context.Context, endpoint string) (*pool.Conn, error) {
// Parse the endpoint to get host and port
addr := endpoint
if addr == "" {
return nil, fmt.Errorf("empty endpoint provided")
}
// Create a temporary dialer for the new endpoint
dialer := func(ctx context.Context) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: c.opt.DialTimeout,
KeepAlive: 30 * time.Second,
}
if c.opt.TLSConfig == nil {
return netDialer.DialContext(ctx, c.opt.Network, addr)
}
return tls.DialWithDialer(netDialer, c.opt.Network, addr, c.opt.TLSConfig)
}
// Create a new connection using the dialer
netConn, err := dialer(ctx)
if err != nil {
return nil, fmt.Errorf("failed to dial new endpoint %s: %w", endpoint, err)
}
// Wrap in pool.Conn
cn := pool.NewConn(netConn)
// Initialize the connection (auth, select db, etc.)
if err := c.initConn(ctx, cn); err != nil {
cn.Close()
return nil, fmt.Errorf("failed to initialize connection to %s: %w", endpoint, err)
}
internal.Logger.Printf(ctx, "hitless: created new connection to endpoint %s for blocking command redirection", endpoint)
return cn, nil
}
// context returns the context for the current connection.
// If the context timeout is enabled, it returns the original context.
// Otherwise, it returns a new background context.
@ -604,8 +742,19 @@ func (c *baseClient) context(ctx context.Context) context.Context {
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close hitless integration first
if c.hitlessIntegration != nil {
if err := c.hitlessIntegration.Close(); err != nil {
internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err)
if firstErr == nil {
firstErr = err
}
}
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
@ -678,14 +827,15 @@ func (c *baseClient) pipelineProcessCmds(
internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
setCmdsErr(cmds, err)
return true, err
}
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
// read all replies
return c.pipelineReadCmds(ctx, cn, rd, cmds)
}); err != nil {
@ -725,14 +875,15 @@ func (c *baseClient) txPipelineProcessCmds(
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
readTimeout, writeTimeout := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout)
if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
setCmdsErr(cmds, err)
return true, err
}
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error {
statusCmd := cmds[0].(*StatusCmd)
// Trim multi and exec.
trimmedCmds := cmds[1 : len(cmds)-1]
@ -837,6 +988,22 @@ func NewClient(opt *Options) *Client {
c.connPool = newConnPool(opt, c.dialHook)
// Initialize hitless upgrades if enabled
if opt.HitlessUpgrades {
if opt.Protocol != 3 {
internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol)
} else {
timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout)
integration, err := initializeHitlessIntegration(&c, opt.HitlessUpgradeConfig, timeoutProvider)
if err != nil {
internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err)
} else {
c.hitlessIntegration = integration
internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for client")
}
}
}
return &c
}
@ -857,6 +1024,12 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
return &clone
}
// GetHitlessIntegration returns the hitless integration instance for monitoring and control.
// Returns nil if hitless upgrades are not enabled.
func (c *Client) GetHitlessIntegration() HitlessIntegration {
return c.hitlessIntegration
}
func (c *Client) Conn() *Conn {
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
}