mirror of
https://github.com/redis/go-redis.git
synced 2025-07-31 05:04:23 +03:00
feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
This commit is contained in:
61
auth/auth.go
Normal file
61
auth/auth.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Package auth package provides authentication-related interfaces and types.
|
||||
// It also includes a basic implementation of credentials using username and password.
|
||||
package auth
|
||||
|
||||
// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider.
|
||||
// It is used to provide credentials for authentication.
|
||||
// The CredentialsListener is used to receive updates when the credentials change.
|
||||
type StreamingCredentialsProvider interface {
|
||||
// Subscribe subscribes to the credentials provider for updates.
|
||||
// It returns the current credentials, a cancel function to unsubscribe from the provider,
|
||||
// and an error if any.
|
||||
// TODO(ndyakov): Should we add context to the Subscribe method?
|
||||
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
|
||||
}
|
||||
|
||||
// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
|
||||
// It is used to unsubscribe from the provider when the credentials are no longer needed.
|
||||
type UnsubscribeFunc func() error
|
||||
|
||||
// CredentialsListener is an interface that defines the methods for a credentials listener.
|
||||
// It is used to receive updates when the credentials change.
|
||||
// The OnNext method is called when the credentials change.
|
||||
// The OnError method is called when an error occurs while requesting the credentials.
|
||||
type CredentialsListener interface {
|
||||
OnNext(credentials Credentials)
|
||||
OnError(err error)
|
||||
}
|
||||
|
||||
// Credentials is an interface that defines the methods for credentials.
|
||||
// It is used to provide the credentials for authentication.
|
||||
type Credentials interface {
|
||||
// BasicAuth returns the username and password for basic authentication.
|
||||
BasicAuth() (username string, password string)
|
||||
// RawCredentials returns the raw credentials as a string.
|
||||
// This can be used to extract the username and password from the raw credentials or
|
||||
// additional information if present in the token.
|
||||
RawCredentials() string
|
||||
}
|
||||
|
||||
type basicAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
// RawCredentials returns the raw credentials as a string.
|
||||
func (b *basicAuth) RawCredentials() string {
|
||||
return b.username + ":" + b.password
|
||||
}
|
||||
|
||||
// BasicAuth returns the username and password for basic authentication.
|
||||
func (b *basicAuth) BasicAuth() (username string, password string) {
|
||||
return b.username, b.password
|
||||
}
|
||||
|
||||
// NewBasicCredentials creates a new Credentials object from the given username and password.
|
||||
func NewBasicCredentials(username, password string) Credentials {
|
||||
return &basicAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
308
auth/auth_test.go
Normal file
308
auth/auth_test.go
Normal file
@ -0,0 +1,308 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockStreamingProvider struct {
|
||||
credentials Credentials
|
||||
err error
|
||||
updates chan Credentials
|
||||
}
|
||||
|
||||
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
||||
return &mockStreamingProvider{
|
||||
credentials: initialCreds,
|
||||
updates: make(chan Credentials, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
||||
if m.err != nil {
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
// Send initial credentials
|
||||
listener.OnNext(m.credentials)
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}()
|
||||
|
||||
return m.credentials, func() error {
|
||||
close(m.updates)
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestStreamingCredentialsProvider(t *testing.T) {
|
||||
t.Run("successful subscription", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
var mu sync.Mutex
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
mu.Lock()
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedCreds) != 1 {
|
||||
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[0] != initialCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
|
||||
// Send an update
|
||||
newCreds := NewBasicCredentials("user2", "pass2")
|
||||
provider.updates <- newCreds
|
||||
|
||||
// Wait for update to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
if len(receivedCreds) != 2 {
|
||||
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[1] != newCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Cancel subscription
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subscription error", func(t *testing.T) {
|
||||
provider := &mockStreamingProvider{
|
||||
err: errors.New("subscription failed"),
|
||||
}
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if cancel != nil {
|
||||
t.Fatal("expected cancel function to be nil")
|
||||
}
|
||||
if creds != nil {
|
||||
t.Fatalf("expected nil credentials, got %v", creds)
|
||||
}
|
||||
if len(receivedCreds) != 0 {
|
||||
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
reauthErr := errors.New("re-auth failed")
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return reauthErr
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedErrors) != 1 {
|
||||
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
||||
}
|
||||
if receivedErrors[0] != reauthErr {
|
||||
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
||||
}
|
||||
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicCredentials(t *testing.T) {
|
||||
t.Run("basic auth", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "user1" {
|
||||
t.Fatalf("expected username 'user1', got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raw credentials", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
raw := creds.RawCredentials()
|
||||
expected := "user1:pass1"
|
||||
if raw != expected {
|
||||
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty username", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "" {
|
||||
t.Fatalf("expected empty username, got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReAuthCredentialsListener(t *testing.T) {
|
||||
t.Run("successful re-auth", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedCreds Credentials
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
receivedCreds = creds
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if onErrCalled {
|
||||
t.Fatal("expected onErr not to be called")
|
||||
}
|
||||
if receivedCreds != creds {
|
||||
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("re-auth failed")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
return expectedErr
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("provider error")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
listener.OnError(expectedErr)
|
||||
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil callbacks", func(t *testing.T) {
|
||||
listener := NewReAuthCredentialsListener(nil, nil)
|
||||
|
||||
// Should not panic
|
||||
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
||||
listener.OnError(errors.New("test error"))
|
||||
})
|
||||
}
|
47
auth/reauth_credentials_listener.go
Normal file
47
auth/reauth_credentials_listener.go
Normal file
@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
|
||||
// It is used to re-authenticate the credentials when they are updated.
|
||||
// It contains:
|
||||
// - reAuth: a function that takes the new credentials and returns an error if any.
|
||||
// - onErr: a function that takes an error and handles it.
|
||||
type ReAuthCredentialsListener struct {
|
||||
reAuth func(credentials Credentials) error
|
||||
onErr func(err error)
|
||||
}
|
||||
|
||||
// OnNext is called when the credentials are updated.
|
||||
// It calls the reAuth function with the new credentials.
|
||||
// If the reAuth function returns an error, it calls the onErr function with the error.
|
||||
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
|
||||
if c.reAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.reAuth(credentials)
|
||||
if err != nil {
|
||||
c.OnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnError is called when an error occurs.
|
||||
// It can be called from both the credentials provider and the reAuth function.
|
||||
func (c *ReAuthCredentialsListener) OnError(err error) {
|
||||
if c.onErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.onErr(err)
|
||||
}
|
||||
|
||||
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
|
||||
// Implements the auth.CredentialsListener interface.
|
||||
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
|
||||
return &ReAuthCredentialsListener{
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)
|
Reference in New Issue
Block a user