mirror of
https://github.com/redis/go-redis.git
synced 2025-06-06 17:40:59 +03:00
* wip * update documentation * add streamingcredentialsprovider in options * fix: put back option in pool creation * add package level comment * Initial re authentication implementation Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested * Change function type name Change CancelProviderFunc to UnsubscribeFunc * add tests * fix race in tests * fix example tests * wip, hooks refactor * fix build * update README.md * update wordlist * update README.md * refactor(auth): early returns in cred listener * fix(doctest): simulate some delay * feat(conn): add close hook on conn * fix(tests): simulate start/stop in mock credentials provider * fix(auth): don't double close the conn * docs(README): mark streaming credentials provider as experimental * fix(auth): streamline auth err proccess * fix(auth): check err on close conn * chore(entraid): use the repo under redis org
309 lines
7.4 KiB
Go
309 lines
7.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type mockStreamingProvider struct {
|
|
credentials Credentials
|
|
err error
|
|
updates chan Credentials
|
|
}
|
|
|
|
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
|
|
return &mockStreamingProvider{
|
|
credentials: initialCreds,
|
|
updates: make(chan Credentials, 10),
|
|
}
|
|
}
|
|
|
|
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
|
|
if m.err != nil {
|
|
return nil, nil, m.err
|
|
}
|
|
|
|
// Send initial credentials
|
|
listener.OnNext(m.credentials)
|
|
|
|
// Start goroutine to handle updates
|
|
go func() {
|
|
for creds := range m.updates {
|
|
listener.OnNext(creds)
|
|
}
|
|
}()
|
|
|
|
return m.credentials, func() error {
|
|
close(m.updates)
|
|
return nil
|
|
}, nil
|
|
}
|
|
|
|
func TestStreamingCredentialsProvider(t *testing.T) {
|
|
t.Run("successful subscription", func(t *testing.T) {
|
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
|
provider := newMockStreamingProvider(initialCreds)
|
|
|
|
var receivedCreds []Credentials
|
|
var receivedErrors []error
|
|
var mu sync.Mutex
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
mu.Lock()
|
|
receivedCreds = append(receivedCreds, creds)
|
|
mu.Unlock()
|
|
return nil
|
|
},
|
|
func(err error) {
|
|
receivedErrors = append(receivedErrors, err)
|
|
},
|
|
)
|
|
|
|
creds, cancel, err := provider.Subscribe(listener)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cancel == nil {
|
|
t.Fatal("expected cancel function to be non-nil")
|
|
}
|
|
if creds != initialCreds {
|
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
|
}
|
|
if len(receivedCreds) != 1 {
|
|
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
|
|
}
|
|
if receivedCreds[0] != initialCreds {
|
|
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
|
|
}
|
|
if len(receivedErrors) != 0 {
|
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
|
}
|
|
|
|
// Send an update
|
|
newCreds := NewBasicCredentials("user2", "pass2")
|
|
provider.updates <- newCreds
|
|
|
|
// Wait for update to be processed
|
|
time.Sleep(100 * time.Millisecond)
|
|
mu.Lock()
|
|
if len(receivedCreds) != 2 {
|
|
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
|
|
}
|
|
if receivedCreds[1] != newCreds {
|
|
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
|
|
}
|
|
mu.Unlock()
|
|
|
|
// Cancel subscription
|
|
if err := cancel(); err != nil {
|
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("subscription error", func(t *testing.T) {
|
|
provider := &mockStreamingProvider{
|
|
err: errors.New("subscription failed"),
|
|
}
|
|
|
|
var receivedCreds []Credentials
|
|
var receivedErrors []error
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
receivedCreds = append(receivedCreds, creds)
|
|
return nil
|
|
},
|
|
func(err error) {
|
|
receivedErrors = append(receivedErrors, err)
|
|
},
|
|
)
|
|
|
|
creds, cancel, err := provider.Subscribe(listener)
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if cancel != nil {
|
|
t.Fatal("expected cancel function to be nil")
|
|
}
|
|
if creds != nil {
|
|
t.Fatalf("expected nil credentials, got %v", creds)
|
|
}
|
|
if len(receivedCreds) != 0 {
|
|
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
|
|
}
|
|
if len(receivedErrors) != 0 {
|
|
t.Fatalf("expected no errors, got %d", len(receivedErrors))
|
|
}
|
|
})
|
|
|
|
t.Run("re-auth error", func(t *testing.T) {
|
|
initialCreds := NewBasicCredentials("user1", "pass1")
|
|
provider := newMockStreamingProvider(initialCreds)
|
|
|
|
reauthErr := errors.New("re-auth failed")
|
|
var receivedErrors []error
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
return reauthErr
|
|
},
|
|
func(err error) {
|
|
receivedErrors = append(receivedErrors, err)
|
|
},
|
|
)
|
|
|
|
creds, cancel, err := provider.Subscribe(listener)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cancel == nil {
|
|
t.Fatal("expected cancel function to be non-nil")
|
|
}
|
|
if creds != initialCreds {
|
|
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
|
|
}
|
|
if len(receivedErrors) != 1 {
|
|
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
|
|
}
|
|
if receivedErrors[0] != reauthErr {
|
|
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
|
|
}
|
|
|
|
if err := cancel(); err != nil {
|
|
t.Fatalf("unexpected error cancelling subscription: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBasicCredentials(t *testing.T) {
|
|
t.Run("basic auth", func(t *testing.T) {
|
|
creds := NewBasicCredentials("user1", "pass1")
|
|
username, password := creds.BasicAuth()
|
|
if username != "user1" {
|
|
t.Fatalf("expected username 'user1', got '%s'", username)
|
|
}
|
|
if password != "pass1" {
|
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
|
}
|
|
})
|
|
|
|
t.Run("raw credentials", func(t *testing.T) {
|
|
creds := NewBasicCredentials("user1", "pass1")
|
|
raw := creds.RawCredentials()
|
|
expected := "user1:pass1"
|
|
if raw != expected {
|
|
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
|
|
}
|
|
})
|
|
|
|
t.Run("empty username", func(t *testing.T) {
|
|
creds := NewBasicCredentials("", "pass1")
|
|
username, password := creds.BasicAuth()
|
|
if username != "" {
|
|
t.Fatalf("expected empty username, got '%s'", username)
|
|
}
|
|
if password != "pass1" {
|
|
t.Fatalf("expected password 'pass1', got '%s'", password)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestReAuthCredentialsListener(t *testing.T) {
|
|
t.Run("successful re-auth", func(t *testing.T) {
|
|
var reAuthCalled bool
|
|
var onErrCalled bool
|
|
var receivedCreds Credentials
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
reAuthCalled = true
|
|
receivedCreds = creds
|
|
return nil
|
|
},
|
|
func(err error) {
|
|
onErrCalled = true
|
|
},
|
|
)
|
|
|
|
creds := NewBasicCredentials("user1", "pass1")
|
|
listener.OnNext(creds)
|
|
|
|
if !reAuthCalled {
|
|
t.Fatal("expected reAuth to be called")
|
|
}
|
|
if onErrCalled {
|
|
t.Fatal("expected onErr not to be called")
|
|
}
|
|
if receivedCreds != creds {
|
|
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
|
|
}
|
|
})
|
|
|
|
t.Run("re-auth error", func(t *testing.T) {
|
|
var reAuthCalled bool
|
|
var onErrCalled bool
|
|
var receivedErr error
|
|
expectedErr := errors.New("re-auth failed")
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
reAuthCalled = true
|
|
return expectedErr
|
|
},
|
|
func(err error) {
|
|
onErrCalled = true
|
|
receivedErr = err
|
|
},
|
|
)
|
|
|
|
creds := NewBasicCredentials("user1", "pass1")
|
|
listener.OnNext(creds)
|
|
|
|
if !reAuthCalled {
|
|
t.Fatal("expected reAuth to be called")
|
|
}
|
|
if !onErrCalled {
|
|
t.Fatal("expected onErr to be called")
|
|
}
|
|
if receivedErr != expectedErr {
|
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
|
}
|
|
})
|
|
|
|
t.Run("on error", func(t *testing.T) {
|
|
var onErrCalled bool
|
|
var receivedErr error
|
|
expectedErr := errors.New("provider error")
|
|
|
|
listener := NewReAuthCredentialsListener(
|
|
func(creds Credentials) error {
|
|
return nil
|
|
},
|
|
func(err error) {
|
|
onErrCalled = true
|
|
receivedErr = err
|
|
},
|
|
)
|
|
|
|
listener.OnError(expectedErr)
|
|
|
|
if !onErrCalled {
|
|
t.Fatal("expected onErr to be called")
|
|
}
|
|
if receivedErr != expectedErr {
|
|
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
|
|
}
|
|
})
|
|
|
|
t.Run("nil callbacks", func(t *testing.T) {
|
|
listener := NewReAuthCredentialsListener(nil, nil)
|
|
|
|
// Should not panic
|
|
listener.OnNext(NewBasicCredentials("user1", "pass1"))
|
|
listener.OnError(errors.New("test error"))
|
|
})
|
|
}
|