mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
add tests
This commit is contained in:
79
redis.go
79
redis.go
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -308,8 +309,15 @@ func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err
|
||||
// we can get it from the *Conn and remove it from the clients pool.
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
poolCn, _ := cn.connPool.Get(ctx)
|
||||
c.connPool.Remove(ctx, poolCn, err)
|
||||
poolCn, getErr := cn.connPool.Get(ctx)
|
||||
if getErr == nil {
|
||||
c.connPool.Remove(ctx, poolCn, err)
|
||||
} else {
|
||||
// if we can't get the pool connection, we can only close the connection
|
||||
if err := cn.Close(); err != nil {
|
||||
log.Printf("failed to close connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -344,7 +352,20 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
var err error
|
||||
cn.Inited = true
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool)
|
||||
var parentHooks hooksMixin
|
||||
pH := ctx.Value(internal.ParentHooksMixinKey{})
|
||||
switch pH := pH.(type) {
|
||||
case nil:
|
||||
parentHooks = hooksMixin{}
|
||||
case hooksMixin:
|
||||
parentHooks = pH.clone()
|
||||
case *hooksMixin:
|
||||
parentHooks = (*pH).clone()
|
||||
default:
|
||||
parentHooks = hooksMixin{}
|
||||
}
|
||||
|
||||
conn := newConn(c.opt, connPool, parentHooks)
|
||||
|
||||
protocol := c.opt.Protocol
|
||||
// By default, use RESP3 in current version.
|
||||
@ -352,28 +373,30 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
protocol = 3
|
||||
}
|
||||
|
||||
var authenticated bool
|
||||
username, password := c.opt.Username, c.opt.Password
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
}
|
||||
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
||||
return err
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||
}
|
||||
} else if c.opt.CredentialsProvider != nil {
|
||||
username, password = c.opt.CredentialsProvider()
|
||||
} else if c.opt.Username != "" || c.opt.Password != "" {
|
||||
username, password = c.opt.Username, c.opt.Password
|
||||
}
|
||||
|
||||
// for redis-server versions that do not support the HELLO command,
|
||||
// RESP2 will continue to be used.
|
||||
if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
authenticated = true
|
||||
// Authentication successful with HELLO command
|
||||
} else if !isRedisError(err) {
|
||||
// When the server responds with the RESP protocol and the result is not a normal
|
||||
// execution result of the HELLO command, we consider it to be an indication that
|
||||
@ -382,15 +405,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
||||
// with different error string results for unsupported commands, making it
|
||||
// difficult to rely on error strings to determine all results.
|
||||
return err
|
||||
}
|
||||
|
||||
if !authenticated && password != "" {
|
||||
return fmt.Errorf("failed to initialize connection: %w", err)
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
if c.opt.DB > 0 {
|
||||
pipe.Select(ctx, c.opt.DB)
|
||||
@ -407,7 +430,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
}
|
||||
|
||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||
@ -422,13 +445,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
||||
// out of order responses later on.
|
||||
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
||||
return err
|
||||
return fmt.Errorf("failed to set client identity: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -547,6 +571,16 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
||||
return c.opt.ReadTimeout
|
||||
}
|
||||
|
||||
// context returns the context for the current connection.
|
||||
// If the context timeout is enabled, it returns the original context.
|
||||
// Otherwise, it returns a new background context.
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Close closes the client, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Client, as the Client is meant to be
|
||||
@ -699,13 +733,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
if c.opt.ContextTimeoutEnabled {
|
||||
return ctx
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Client is a Redis client representing a pool of zero or more underlying connections.
|
||||
@ -752,7 +779,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Conn() *Conn {
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
|
||||
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin.clone())
|
||||
}
|
||||
|
||||
// Do create a Cmd from the args and processes the cmd.
|
||||
@ -763,6 +790,7 @@ func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
|
||||
}
|
||||
|
||||
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
|
||||
ctx = context.WithValue(ctx, internal.ParentHooksMixinKey{}, c.hooksMixin)
|
||||
err := c.processHook(ctx, cmd)
|
||||
cmd.SetErr(err)
|
||||
return err
|
||||
@ -888,7 +916,7 @@ type Conn struct {
|
||||
hooksMixin
|
||||
}
|
||||
|
||||
func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn {
|
||||
c := Conn{
|
||||
baseClient: baseClient{
|
||||
opt: opt,
|
||||
@ -898,6 +926,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
|
||||
|
||||
c.cmdable = c.Process
|
||||
c.statefulCmdable = c.Process
|
||||
c.hooksMixin = parentHooks
|
||||
c.initHooks(hooks{
|
||||
dial: c.baseClient.dial,
|
||||
process: c.baseClient.process,
|
||||
|
Reference in New Issue
Block a user