// Mirror of https://github.com/redis/go-redis.git (synced 2025-10-21 20:53:41 +03:00).
// File size: 85 lines, 1.9 KiB. Language: Go.

package streaming
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9/internal/pool"
|
|
)
|
|
|
|
// ReAuthPoolHook is a pool hook that keeps track of connections that have been
// marked for re-authentication and triggers the re-authentication when such a
// connection is put back into the pool.
type ReAuthPoolHook struct {
	// lock guards shouldReAuth.
	lock sync.RWMutex
	// shouldReAuth maps a connection ID to the re-auth callback registered for
	// that connection. The callback receives an error describing why the
	// connection could not be prepared for re-auth, or nil.
	shouldReAuth map[uint64]func(error)

	// workers is a buffered channel shared by the background re-auth
	// goroutines (presumably to bound how many run at once).
	workers chan struct{}
	// reAuthTimeout is how long a background re-auth waits for the connection
	// before reporting pool.ErrConnUnusableTimeout.
	reAuthTimeout time.Duration
}
|
|
|
|
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
|
return &ReAuthPoolHook{
|
|
shouldReAuth: make(map[uint64]func(error)),
|
|
workers: make(chan struct{}, poolSize),
|
|
reAuthTimeout: reAuthTimeout,
|
|
}
|
|
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
|
r.lock.Lock()
|
|
defer r.lock.Unlock()
|
|
r.shouldReAuth[connID] = reAuthFn
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) ShouldReAuth(connID uint64) bool {
|
|
r.lock.RLock()
|
|
defer r.lock.RUnlock()
|
|
_, ok := r.shouldReAuth[connID]
|
|
return ok
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
|
|
r.lock.Lock()
|
|
defer r.lock.Unlock()
|
|
delete(r.shouldReAuth, connID)
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) OnGet(_ context.Context, _ *pool.Conn, _ bool) error {
|
|
// noop
|
|
return nil
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
|
if reAuthFn, ok := r.shouldReAuth[conn.GetID()]; ok {
|
|
go func() {
|
|
<-r.workers
|
|
var err error
|
|
timeout := time.After(r.reAuthTimeout)
|
|
for !conn.Usable.CompareAndSwap(true, false) {
|
|
select {
|
|
case <-timeout:
|
|
err = pool.ErrConnUnusableTimeout
|
|
default:
|
|
time.Sleep(time.Millisecond)
|
|
// connection closed, cannot re-authenticate
|
|
}
|
|
}
|
|
|
|
reAuthFn(err)
|
|
|
|
conn.Usable.Store(true)
|
|
r.workers <- struct{}{}
|
|
}()
|
|
}
|
|
|
|
// the reauth will happen in background, as far as the pool is concerned:
|
|
// pool the connection, don't remove it, no error
|
|
return true, false, nil
|
|
}
|
|
|
|
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
|
r.ClearReAuthMark(conn.GetID())
|
|
}
|
|
|
|
// Compile-time check that *ReAuthPoolHook implements pool.PoolHook.
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
|