1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

fix(auth): streamline auth err proccess

This commit is contained in:
Nedyalko Dyakov
2025-05-14 10:28:54 +03:00
parent a6a2c9d3b4
commit 52286114a3

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"log"
"net"
"sync"
"sync/atomic"
@ -285,21 +284,22 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, cn *pool.Conn) auth.CredentialsListener {
connPool := pool.NewSingleConnPool(c.connPool, cn)
// hooksMixin are intentionally empty here
conn := newConn(c.opt, connPool, nil)
ctx = c.context(ctx)
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(ctx, conn),
c.onAuthenticationErr(ctx, conn),
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}
func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error {
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
@ -308,22 +308,13 @@ func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(creden
return err
}
}
func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) {
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
// since the connection pool of the *Conn will actually return us the underlying pool.Conn,
// we can get it from the *Conn and remove it from the clients pool.
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
poolCn, getErr := cn.connPool.Get(ctx)
if getErr == nil {
c.connPool.Remove(ctx, poolCn, err)
} else {
// if we can't get the pool connection, we can only close the connection
if err := cn.Close(); err != nil {
log.Printf("failed to close connection: %v", err)
}
}
c.connPool.CloseConn(poolCn)
}
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
}
}
}
@ -368,7 +359,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(ctx, cn))
Subscribe(c.newReAuthCredentialsListener(cn))
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
}
@ -401,7 +392,11 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return err
} else if password != "" {
// Try legacy AUTH command if HELLO failed
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
if username != "" {
err = conn.AuthACL(ctx, username, password).Err()
} else {
err = conn.Auth(ctx, password).Err()
}
if err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}