From 52286114a3880357d9ed8e7eb24de1fdf85d4c39 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 14 May 2025 10:28:54 +0300 Subject: [PATCH] fix(auth): streamline auth err proccess --- redis.go | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/redis.go b/redis.go index 22b52981..32379d59 100644 --- a/redis.go +++ b/redis.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "log" "net" "sync" "sync/atomic" @@ -285,21 +284,22 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, cn *pool.Conn) auth.CredentialsListener { - connPool := pool.NewSingleConnPool(c.connPool, cn) - // hooksMixin are intentionally empty here - conn := newConn(c.opt, connPool, nil) - ctx = c.context(ctx) +func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener { return auth.NewReAuthCredentialsListener( - c.reAuthConnection(ctx, conn), - c.onAuthenticationErr(ctx, conn), + c.reAuthConnection(poolCn), + c.onAuthenticationErr(poolCn), ) } -func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error { +func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error { return func(credentials auth.Credentials) error { var err error username, password := credentials.BasicAuth() + ctx := context.Background() + connPool := pool.NewSingleConnPool(c.connPool, poolCn) + // hooksMixin are intentionally empty here + cn := newConn(c.opt, connPool, nil) + if username != "" { err = cn.AuthACL(ctx, username, password).Err() } else { @@ -308,22 +308,13 @@ func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(creden return err } } -func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) { +func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) { return func(err error) { - // since the connection pool of the *Conn will actually return us the underlying pool.Conn, - // we can get it from the *Conn and remove it from the clients pool. if err != nil { if isBadConn(err, false, c.opt.Addr) { - poolCn, getErr := cn.connPool.Get(ctx) - if getErr == nil { - c.connPool.Remove(ctx, poolCn, err) - } else { - // if we can't get the pool connection, we can only close the connection - if err := cn.Close(); err != nil { - log.Printf("failed to close connection: %v", err) - } - } + c.connPool.CloseConn(poolCn) } + internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err) } } } @@ -368,7 +359,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. - Subscribe(c.newReAuthCredentialsListener(ctx, cn)) + Subscribe(c.newReAuthCredentialsListener(cn)) if err != nil { return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) } @@ -401,7 +392,11 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return err } else if password != "" { // Try legacy AUTH command if HELLO failed - err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password)) + if username != "" { + err = conn.AuthACL(ctx, username, password).Err() + } else { + err = conn.Auth(ctx, password).Err() + } if err != nil { return fmt.Errorf("failed to authenticate: %w", err) }