mirror of
https://github.com/redis/go-redis.git
synced 2025-04-19 07:22:17 +03:00
Initial re authentication implementation
Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested
This commit is contained in:
parent
847f1f9daa
commit
40a89c56cc
@ -9,6 +9,7 @@ type StreamingCredentialsProvider interface {
|
|||||||
// Subscribe subscribes to the credentials provider for updates.
|
// Subscribe subscribes to the credentials provider for updates.
|
||||||
// It returns the current credentials, a cancel function to unsubscribe from the provider,
|
// It returns the current credentials, a cancel function to unsubscribe from the provider,
|
||||||
// and an error if any.
|
// and an error if any.
|
||||||
|
// TODO(ndyakov): Should we add context to the Subscribe method?
|
||||||
Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error)
|
Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
45
auth/reauth_credentials_listener.go
Normal file
45
auth/reauth_credentials_listener.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||||
|
// It is used to re-authenticate the credentials when they are updated.
|
||||||
|
// It contains:
|
||||||
|
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||||
|
// - onErr: a function that takes an error and handles it.
|
||||||
|
type ReAuthCredentialsListener struct {
|
||||||
|
reAuth func(credentials Credentials) error
|
||||||
|
onErr func(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnNext is called when the credentials are updated.
|
||||||
|
// It calls the reAuth function with the new credentials.
|
||||||
|
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||||
|
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||||
|
if c.reAuth != nil {
|
||||||
|
err := c.reAuth(credentials)
|
||||||
|
if err != nil {
|
||||||
|
if c.onErr != nil {
|
||||||
|
c.onErr(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError is called when an error occurs.
|
||||||
|
// It can be called from both the credentials provider and the reAuth function.
|
||||||
|
func (c *ReAuthCredentialsListener) OnError(err error) {
|
||||||
|
if c.onErr != nil {
|
||||||
|
c.onErr(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
|
||||||
|
// Implements the auth.CredentialsListener interface.
|
||||||
|
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
|
||||||
|
return &ReAuthCredentialsListener{
|
||||||
|
reAuth: reAuth,
|
||||||
|
onErr: onErr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
|
||||||
|
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
|
@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewClient: func(opt *Options) *Client {
|
NewClient: func(opt *Options) *Client {
|
||||||
c := NewClient(opt)
|
c := NewClient(opt)
|
||||||
c.baseClient.onClose = func() error {
|
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||||
closeCounter.increment(opt.Addr)
|
closeCounter.increment(opt.Addr)
|
||||||
return nil
|
return nil
|
||||||
}
|
})
|
||||||
return c
|
return c
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
createCounter.increment(opt.Addr)
|
createCounter.increment(opt.Addr)
|
||||||
c := NewClient(opt)
|
c := NewClient(opt)
|
||||||
c.baseClient.onClose = func() error {
|
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
|
||||||
closeCounter.increment(opt.Addr)
|
closeCounter.increment(opt.Addr)
|
||||||
return nil
|
return nil
|
||||||
}
|
})
|
||||||
return c
|
return c
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
56
redis.go
56
redis.go
@ -283,7 +283,15 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
|||||||
return cn, nil
|
return cn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Credentials) error {
|
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener {
|
||||||
|
return auth.NewReAuthCredentialsListener(
|
||||||
|
c.reAuthConnection(c.context(ctx), conn),
|
||||||
|
c.onAuthenticationErr(c.context(ctx), conn),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error {
|
||||||
|
return func(credentials auth.Credentials) error {
|
||||||
var err error
|
var err error
|
||||||
username, password := credentials.BasicAuth()
|
username, password := credentials.BasicAuth()
|
||||||
if username != "" {
|
if username != "" {
|
||||||
@ -292,6 +300,40 @@ func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Cred
|
|||||||
err = cn.Auth(ctx, password).Err()
|
err = cn.Auth(ctx, password).Err()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) {
|
||||||
|
return func(err error) {
|
||||||
|
// since the connection pool of the *Conn will actually return us the underlying pool.Conn,
|
||||||
|
// we can get it from the *Conn and remove it from the clients pool.
|
||||||
|
if err != nil {
|
||||||
|
if isBadConn(err, false, c.opt.Addr) {
|
||||||
|
poolCn, _ := cn.connPool.Get(ctx)
|
||||||
|
c.connPool.Remove(ctx, poolCn, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||||
|
onClose := c.onClose
|
||||||
|
return func() error {
|
||||||
|
var firstErr error
|
||||||
|
err := newOnClose()
|
||||||
|
// Even if we have an error we would like to execute the onClose hook
|
||||||
|
// if it exists. We will return the first error that occurred.
|
||||||
|
// This is to keep error handling consistent with the rest of the code.
|
||||||
|
if err != nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
if onClose != nil {
|
||||||
|
err = onClose()
|
||||||
|
if err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||||
@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
|
|
||||||
var authenticated bool
|
var authenticated bool
|
||||||
username, password := c.opt.Username, c.opt.Password
|
username, password := c.opt.Username, c.opt.Password
|
||||||
if c.opt.CredentialsProviderContext != nil {
|
if c.opt.StreamingCredentialsProvider != nil {
|
||||||
|
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||||
|
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
|
||||||
|
username, password = credentials.BasicAuth()
|
||||||
|
} else if c.opt.CredentialsProviderContext != nil {
|
||||||
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !authenticated && password != "" {
|
if !authenticated && password != "" {
|
||||||
err = c.reAuth(ctx, conn, auth.NewBasicCredentials(username, password))
|
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -257,7 +257,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
|||||||
|
|
||||||
connPool = newConnPool(opt, rdb.dialHook)
|
connPool = newConnPool(opt, rdb.dialHook)
|
||||||
rdb.connPool = connPool
|
rdb.connPool = connPool
|
||||||
rdb.onClose = failover.Close
|
rdb.onClose = rdb.wrappedOnClose(failover.Close)
|
||||||
|
|
||||||
failover.mu.Lock()
|
failover.mu.Lock()
|
||||||
failover.onFailover = func(ctx context.Context, addr string) {
|
failover.onFailover = func(ctx context.Context, addr string) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user