mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
This commit is contained in:
308
auth/auth_test.go
Normal file
308
auth/auth_test.go
Normal file
@ -0,0 +1,308 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockStreamingProvider struct {
|
||||
credentials Credentials
|
||||
err error
|
||||
updates chan Credentials
|
||||
}
|
||||
|
||||
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
||||
return &mockStreamingProvider{
|
||||
credentials: initialCreds,
|
||||
updates: make(chan Credentials, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
||||
if m.err != nil {
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
// Send initial credentials
|
||||
listener.OnNext(m.credentials)
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}()
|
||||
|
||||
return m.credentials, func() error {
|
||||
close(m.updates)
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestStreamingCredentialsProvider(t *testing.T) {
|
||||
t.Run("successful subscription", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
var mu sync.Mutex
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
mu.Lock()
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedCreds) != 1 {
|
||||
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[0] != initialCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
|
||||
// Send an update
|
||||
newCreds := NewBasicCredentials("user2", "pass2")
|
||||
provider.updates <- newCreds
|
||||
|
||||
// Wait for update to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
if len(receivedCreds) != 2 {
|
||||
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if receivedCreds[1] != newCreds {
|
||||
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Cancel subscription
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subscription error", func(t *testing.T) {
|
||||
provider := &mockStreamingProvider{
|
||||
err: errors.New("subscription failed"),
|
||||
}
|
||||
|
||||
var receivedCreds []Credentials
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
receivedCreds = append(receivedCreds, creds)
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if cancel != nil {
|
||||
t.Fatal("expected cancel function to be nil")
|
||||
}
|
||||
if creds != nil {
|
||||
t.Fatalf("expected nil credentials, got %v", creds)
|
||||
}
|
||||
if len(receivedCreds) != 0 {
|
||||
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
||||
}
|
||||
if len(receivedErrors) != 0 {
|
||||
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
initialCreds := NewBasicCredentials("user1", "pass1")
|
||||
provider := newMockStreamingProvider(initialCreds)
|
||||
|
||||
reauthErr := errors.New("re-auth failed")
|
||||
var receivedErrors []error
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return reauthErr
|
||||
},
|
||||
func(err error) {
|
||||
receivedErrors = append(receivedErrors, err)
|
||||
},
|
||||
)
|
||||
|
||||
creds, cancel, err := provider.Subscribe(listener)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cancel == nil {
|
||||
t.Fatal("expected cancel function to be non-nil")
|
||||
}
|
||||
if creds != initialCreds {
|
||||
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
||||
}
|
||||
if len(receivedErrors) != 1 {
|
||||
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
||||
}
|
||||
if receivedErrors[0] != reauthErr {
|
||||
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
||||
}
|
||||
|
||||
if err := cancel(); err != nil {
|
||||
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicCredentials(t *testing.T) {
|
||||
t.Run("basic auth", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "user1" {
|
||||
t.Fatalf("expected username 'user1', got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raw credentials", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
raw := creds.RawCredentials()
|
||||
expected := "user1:pass1"
|
||||
if raw != expected {
|
||||
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty username", func(t *testing.T) {
|
||||
creds := NewBasicCredentials("", "pass1")
|
||||
username, password := creds.BasicAuth()
|
||||
if username != "" {
|
||||
t.Fatalf("expected empty username, got '%s'", username)
|
||||
}
|
||||
if password != "pass1" {
|
||||
t.Fatalf("expected password 'pass1', got '%s'", password)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReAuthCredentialsListener(t *testing.T) {
|
||||
t.Run("successful re-auth", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedCreds Credentials
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
receivedCreds = creds
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if onErrCalled {
|
||||
t.Fatal("expected onErr not to be called")
|
||||
}
|
||||
if receivedCreds != creds {
|
||||
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-auth error", func(t *testing.T) {
|
||||
var reAuthCalled bool
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("re-auth failed")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
reAuthCalled = true
|
||||
return expectedErr
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
creds := NewBasicCredentials("user1", "pass1")
|
||||
listener.OnNext(creds)
|
||||
|
||||
if !reAuthCalled {
|
||||
t.Fatal("expected reAuth to be called")
|
||||
}
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
var onErrCalled bool
|
||||
var receivedErr error
|
||||
expectedErr := errors.New("provider error")
|
||||
|
||||
listener := NewReAuthCredentialsListener(
|
||||
func(creds Credentials) error {
|
||||
return nil
|
||||
},
|
||||
func(err error) {
|
||||
onErrCalled = true
|
||||
receivedErr = err
|
||||
},
|
||||
)
|
||||
|
||||
listener.OnError(expectedErr)
|
||||
|
||||
if !onErrCalled {
|
||||
t.Fatal("expected onErr to be called")
|
||||
}
|
||||
if receivedErr != expectedErr {
|
||||
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil callbacks", func(t *testing.T) {
|
||||
listener := NewReAuthCredentialsListener(nil, nil)
|
||||
|
||||
// Should not panic
|
||||
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
||||
listener.OnError(errors.New("test error"))
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user