1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-28 06:42:00 +03:00

feat: Introducing StreamingCredentialsProvider for token based authentication (#3320)

* wip

* update documentation

* add streamingcredentialsprovider in options

* fix: put back option in pool creation

* add package level comment

* Initial re authentication implementation

Introduces the StreamingCredentialsProvider as the CredentialsProvider
with the highest priority.

TODO: needs to be tested

* Change function type name

Change CancelProviderFunc to UnsubscribeFunc

* add tests

* fix race in tests

* fix example tests

* wip, hooks refactor

* fix build

* update README.md

* update wordlist

* update README.md

* refactor(auth): early returns in cred listener

* fix(doctest): simulate some delay

* feat(conn): add close hook on conn

* fix(tests): simulate start/stop in mock credentials provider

* fix(auth): don't double close the conn

* docs(README): mark streaming credentials provider as experimental

* fix(auth): streamline auth err proccess

* fix(auth): check err on close conn

* chore(entraid): use the repo under redis org
This commit is contained in:
Nedyalko Dyakov
2025-05-27 16:25:20 +03:00
committed by GitHub
parent 28a3c97409
commit 86d418f940
20 changed files with 1103 additions and 130 deletions

View File

@ -14,6 +14,7 @@ import (
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/auth"
)
type redisHookError struct{}
@ -727,6 +728,174 @@ var _ = Describe("Dialer connection timeouts", func() {
Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
})
})
var _ = Describe("Credentials Provider Priority", func() {
var client *redis.Client
var opt *redis.Options
var recorder *commandRecorder
BeforeEach(func() {
recorder = newCommandRecorder(10)
})
AfterEach(func() {
if client != nil {
Expect(client.Close()).NotTo(HaveOccurred())
}
})
It("should use streaming provider when available", func() {
streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: streamingCreds,
updates: make(chan auth.Credentials, 1),
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
})
It("should use context provider when streaming provider is not available", func() {
ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
username, password := ctxCreds.BasicAuth()
return username, password, nil
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
})
It("should use regular provider when streaming and context providers are not available", func() {
providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
CredentialsProvider: func() (string, string) {
username, password := providerCreds.BasicAuth()
return username, password
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
})
It("should use username/password fields when no providers are set", func() {
opt = &redis.Options{
Username: "field_user",
Password: "field_pass",
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
})
It("should use empty credentials when nothing is set", func() {
opt = &redis.Options{}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// no pass, ok
Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
Expect(recorder.Contains("AUTH")).To(BeFalse())
})
It("should handle credential updates from streaming provider", func() {
initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
updatesChan := make(chan auth.Credentials, 1)
opt = &redis.Options{
StreamingCredentialsProvider: &mockStreamingProvider{
credentials: initialCreds,
updates: updatesChan,
},
}
client = redis.NewClient(opt)
client.AddHook(recorder.Hook())
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
// Update credentials
opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
// wrongpass
Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
close(updatesChan)
})
})
type mockStreamingProvider struct {
credentials auth.Credentials
err error
updates chan auth.Credentials
}
func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
if m.err != nil {
return nil, nil, m.err
}
// Start goroutine to handle updates
go func() {
for creds := range m.updates {
m.credentials = creds
listener.OnNext(creds)
}
}()
return m.credentials, func() (err error) {
defer func() {
if r := recover(); r != nil {
// this is just a mock:
// allow multiple closes from multiple listeners
}
}()
return
}, nil
}
var _ = Describe("Client creation", func() {
Context("simple client with nil options", func() {
It("panics", func() {