mirror of
https://github.com/redis/go-redis.git
synced 2025-07-28 06:42:00 +03:00
add tests
This commit is contained in:
169
redis_test.go
169
redis_test.go
@ -14,6 +14,7 @@ import (
|
||||
. "github.com/bsm/gomega"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
)
|
||||
|
||||
type redisHookError struct{}
|
||||
@ -727,3 +728,171 @@ var _ = Describe("Dialer connection timeouts", func() {
|
||||
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Credentials Provider Priority", func() {
|
||||
var client *redis.Client
|
||||
var opt *redis.Options
|
||||
var recorder *commandRecorder
|
||||
|
||||
BeforeEach(func() {
|
||||
recorder = newCommandRecorder(10)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if client != nil {
|
||||
Expect(client.Close()).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
It("should use streaming provider when available", func() {
|
||||
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
|
||||
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||
username, password := ctxCreds.BasicAuth()
|
||||
return username, password, nil
|
||||
},
|
||||
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||
credentials: streamingCreds,
|
||||
updates: make(chan auth.Credentials, 1),
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use context provider when streaming provider is not available", func() {
|
||||
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
|
||||
username, password := ctxCreds.BasicAuth()
|
||||
return username, password, nil
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use regular provider when streaming and context providers are not available", func() {
|
||||
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
CredentialsProvider: func() (string, string) {
|
||||
username, password := providerCreds.BasicAuth()
|
||||
return username, password
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use username/password fields when no providers are set", func() {
|
||||
opt = &redis.Options{
|
||||
Username: "field_user",
|
||||
Password: "field_pass",
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should use empty credentials when nothing is set", func() {
|
||||
opt = &redis.Options{}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// no pass, ok
|
||||
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH")).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should handle credential updates from streaming provider", func() {
|
||||
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
|
||||
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
|
||||
|
||||
opt = &redis.Options{
|
||||
StreamingCredentialsProvider: &mockStreamingProvider{
|
||||
credentials: initialCreds,
|
||||
updates: make(chan auth.Credentials, 1),
|
||||
},
|
||||
}
|
||||
|
||||
client = redis.NewClient(opt)
|
||||
client.AddHook(recorder.Hook())
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
|
||||
|
||||
// Update credentials
|
||||
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
|
||||
// wrongpass
|
||||
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
|
||||
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
type mockStreamingProvider struct {
|
||||
credentials auth.Credentials
|
||||
err error
|
||||
updates chan auth.Credentials
|
||||
}
|
||||
|
||||
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
|
||||
if m.err != nil {
|
||||
return nil, nil, m.err
|
||||
}
|
||||
|
||||
// Send initial credentials
|
||||
listener.OnNext(m.credentials)
|
||||
|
||||
// Start goroutine to handle updates
|
||||
go func() {
|
||||
for creds := range m.updates {
|
||||
listener.OnNext(creds)
|
||||
}
|
||||
}()
|
||||
|
||||
return m.credentials, func() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// this is just a mock:
|
||||
// allow multiple closes from multiple listeners
|
||||
}
|
||||
}()
|
||||
close(m.updates)
|
||||
return
|
||||
}, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user