mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
fix(auth): streamline auth err proccess
This commit is contained in:
41
redis.go
41
redis.go
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -285,21 +284,22 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, cn *pool.Conn) auth.CredentialsListener {
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
// hooksMixin are intentionally empty here
|
||||
conn := newConn(c.opt, connPool, nil)
|
||||
ctx = c.context(ctx)
|
||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||
return auth.NewReAuthCredentialsListener(
|
||||
c.reAuthConnection(ctx, conn),
|
||||
c.onAuthenticationErr(ctx, conn),
|
||||
c.reAuthConnection(poolCn),
|
||||
c.onAuthenticationErr(poolCn),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error {
|
||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||
return func(credentials auth.Credentials) error {
|
||||
var err error
|
||||
username, password := credentials.BasicAuth()
|
||||
ctx := context.Background()
|
||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||
// hooksMixin are intentionally empty here
|
||||
cn := newConn(c.opt, connPool, nil)
|
||||
|
||||
if username != "" {
|
||||
err = cn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
@ -308,22 +308,13 @@ func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(creden
|
||||
return err
|
||||
}
|
||||
}
|
||||
func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) {
|
||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||
return func(err error) {
|
||||
// since the connection pool of the *Conn will actually return us the underlying pool.Conn,
|
||||
// we can get it from the *Conn and remove it from the clients pool.
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
poolCn, getErr := cn.connPool.Get(ctx)
|
||||
if getErr == nil {
|
||||
c.connPool.Remove(ctx, poolCn, err)
|
||||
} else {
|
||||
// if we can't get the pool connection, we can only close the connection
|
||||
if err := cn.Close(); err != nil {
|
||||
log.Printf("failed to close connection: %v", err)
|
||||
}
|
||||
}
|
||||
c.connPool.CloseConn(poolCn)
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -368,7 +359,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(ctx, cn))
|
||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
}
|
||||
@ -401,7 +392,11 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
return err
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
|
||||
if username != "" {
|
||||
err = conn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = conn.Auth(ctx, password).Err()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user