1
0
mirror of https://github.com/redis/go-redis.git synced 2025-09-02 22:01:16 +03:00

example circuit breaker implementation, fast fail on big pools

This commit is contained in:
Nedyalko Dyakov
2025-09-02 18:45:28 +03:00
parent b34f8270c6
commit 32ddb96a6b
6 changed files with 650 additions and 11 deletions

269
hitless/circuit_breaker.go Normal file
View File

@@ -0,0 +1,269 @@
package hitless
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
)
// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32
const (
// CircuitBreakerClosed - normal operation, requests allowed
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen - failing fast, requests rejected
CircuitBreakerOpen
// CircuitBreakerHalfOpen - testing if service recovered
CircuitBreakerHalfOpen
)
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
// Configuration
failureThreshold int // Number of failures before opening
resetTimeout time.Duration // How long to stay open before testing
maxRequests int // Max requests allowed in half-open state
// State tracking (atomic for lock-free access)
state atomic.Int32 // CircuitBreakerState
failures atomic.Int64 // Current failure count
successes atomic.Int64 // Success count in half-open state
requests atomic.Int64 // Request count in half-open state
lastFailureTime atomic.Int64 // Unix timestamp of last failure
lastSuccessTime atomic.Int64 // Unix timestamp of last success
// Endpoint identification
endpoint string
config *Config
}
// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
// Use sensible defaults if not configured
failureThreshold := 10
resetTimeout := 500 * time.Millisecond
maxRequests := 10
// These could be added to Config in the future without breaking API
// For now, use internal defaults that work well
return &CircuitBreaker{
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
maxRequests: maxRequests,
endpoint: endpoint,
config: config,
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
}
}
// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
state := CircuitBreakerState(cb.state.Load())
if state == CircuitBreakerOpen {
// Check if we should transition to half-open
if cb.shouldAttemptReset() {
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker for %s transitioning to half-open", cb.endpoint)
}
return false // Now in half-open state, allow requests
}
}
return true // Still open
}
return false
}
// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
return time.Since(lastFailure) >= cb.resetTimeout
}
// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
// Fast path: if circuit is open, fail immediately
if cb.IsOpen() {
return ErrCircuitBreakerOpen
}
state := CircuitBreakerState(cb.state.Load())
// In half-open state, limit the number of requests
if state == CircuitBreakerHalfOpen {
requests := cb.requests.Add(1)
if requests > int64(cb.maxRequests) {
cb.requests.Add(-1) // Revert the increment
return ErrCircuitBreakerOpen
}
}
// Execute the function
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.lastFailureTime.Store(time.Now().Unix())
failures := cb.failures.Add(1)
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker opened for endpoint %s after %d failures",
cb.endpoint, failures)
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if cb.config != nil && cb.config.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker reopened for endpoint %s due to failure in half-open state",
cb.endpoint)
}
}
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.lastSuccessTime.Store(time.Now().Unix())
state := CircuitBreakerState(cb.state.Load())
if state == CircuitBreakerClosed {
// Reset failure count on success in closed state
cb.failures.Store(0)
} else if state == CircuitBreakerHalfOpen {
successes := cb.successes.Add(1)
// If we've had enough successful requests, close the circuit
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if cb.config != nil && cb.config.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(),
"hitless: circuit breaker closed for endpoint %s after %d successful requests",
cb.endpoint, successes)
}
}
}
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(cb.state.Load())
}
// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
return CircuitBreakerStats{
Endpoint: cb.endpoint,
State: cb.GetState(),
Failures: cb.failures.Load(),
Successes: cb.successes.Load(),
Requests: cb.requests.Load(),
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
}
}
// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
Endpoint string
State CircuitBreakerState
Failures int64
Successes int64
Requests int64
LastFailureTime time.Time
LastSuccessTime time.Time
}
// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
breakers sync.Map // map[string]*CircuitBreaker
config *Config
}
// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
return &CircuitBreakerManager{
config: config,
}
}
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
if breaker, ok := cbm.breakers.Load(endpoint); ok {
return breaker.(*CircuitBreaker)
}
// Create new circuit breaker
newBreaker := newCircuitBreaker(endpoint, cbm.config)
actual, _ := cbm.breakers.LoadOrStore(endpoint, newBreaker)
return actual.(*CircuitBreaker)
}
// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
var stats []CircuitBreakerStats
cbm.breakers.Range(func(key, value interface{}) bool {
breaker := value.(*CircuitBreaker)
stats = append(stats, breaker.GetStats())
return true
})
return stats
}
// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
cbm.breakers.Range(func(key, value interface{}) bool {
breaker := value.(*CircuitBreaker)
breaker.state.Store(int32(CircuitBreakerClosed))
breaker.failures.Store(0)
breaker.successes.Store(0)
breaker.requests.Store(0)
breaker.lastFailureTime.Store(0)
breaker.lastSuccessTime.Store(0)
return true
})
}

View File

@@ -0,0 +1,292 @@
package hitless
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9/logging"
)
func TestCircuitBreaker(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError, // Reduce noise in tests
}
t.Run("InitialState", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
if cb.IsOpen() {
t.Error("Circuit breaker should start in closed state")
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("SuccessfulExecution", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("FailureThreshold", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Fail 4 times (below threshold of 5)
for i := 0; i < 4; i++ {
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Circuit should still be closed after %d failures", i+1)
}
}
// 5th failure should open the circuit
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("OpenCircuitFailsFast", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Now it should fail fast
err := cb.Execute(func() error {
t.Error("Function should not be called when circuit is open")
return nil
})
if err != ErrCircuitBreakerOpen {
t.Errorf("Expected ErrCircuitBreakerOpen, got %v", err)
}
})
t.Run("HalfOpenTransition", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
cb.resetTimeout = 100 * time.Millisecond // Short timeout for testing
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Wait for reset timeout
time.Sleep(150 * time.Millisecond)
// Next call should transition to half-open
executed := false
err := cb.Execute(func() error {
executed = true
return nil // Success
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !executed {
t.Error("Function should have been executed in half-open state")
}
})
t.Run("HalfOpenToClosedTransition", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
cb.resetTimeout = 50 * time.Millisecond
cb.maxRequests = 3
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// Execute successful requests in half-open state
for i := 0; i < 3; i++ {
err := cb.Execute(func() error {
return nil // Success
})
if err != nil {
t.Errorf("Expected no error on attempt %d, got %v", i+1, err)
}
}
// Circuit should now be closed
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, cb.GetState())
}
})
t.Run("HalfOpenToOpenOnFailure", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
cb.resetTimeout = 50 * time.Millisecond
testError := errors.New("test error")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
// Wait for reset timeout
time.Sleep(100 * time.Millisecond)
// First request in half-open state fails
err := cb.Execute(func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
// Circuit should be open again
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state %v, got %v", CircuitBreakerOpen, cb.GetState())
}
})
t.Run("Stats", func(t *testing.T) {
cb := newCircuitBreaker("test-endpoint:6379", config)
testError := errors.New("test error")
// Execute some operations
cb.Execute(func() error { return testError }) // Failure
cb.Execute(func() error { return testError }) // Failure
stats := cb.GetStats()
if stats.Endpoint != "test-endpoint:6379" {
t.Errorf("Expected endpoint 'test-endpoint:6379', got %s", stats.Endpoint)
}
if stats.Failures != 2 {
t.Errorf("Expected 2 failures, got %d", stats.Failures)
}
if stats.State != CircuitBreakerClosed {
t.Errorf("Expected state %v, got %v", CircuitBreakerClosed, stats.State)
}
// Test that success resets failure count
cb.Execute(func() error { return nil }) // Success
stats = cb.GetStats()
if stats.Failures != 0 {
t.Errorf("Expected 0 failures after success, got %d", stats.Failures)
}
})
}
func TestCircuitBreakerManager(t *testing.T) {
config := &Config{
LogLevel: logging.LogLevelError,
}
t.Run("GetCircuitBreaker", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
cb3 := manager.GetCircuitBreaker("endpoint1:6379") // Same as cb1
if cb1 == cb2 {
t.Error("Different endpoints should have different circuit breakers")
}
if cb1 != cb3 {
t.Error("Same endpoint should return the same circuit breaker")
}
})
t.Run("GetAllStats", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
// Create circuit breakers for different endpoints
cb1 := manager.GetCircuitBreaker("endpoint1:6379")
cb2 := manager.GetCircuitBreaker("endpoint2:6379")
// Execute some operations
cb1.Execute(func() error { return nil })
cb2.Execute(func() error { return errors.New("test error") })
stats := manager.GetAllStats()
if len(stats) != 2 {
t.Errorf("Expected 2 circuit breaker stats, got %d", len(stats))
}
// Check that we have stats for both endpoints
endpoints := make(map[string]bool)
for _, stat := range stats {
endpoints[stat.Endpoint] = true
}
if !endpoints["endpoint1:6379"] || !endpoints["endpoint2:6379"] {
t.Error("Missing stats for expected endpoints")
}
})
t.Run("Reset", func(t *testing.T) {
manager := newCircuitBreakerManager(config)
testError := errors.New("test error")
cb := manager.GetCircuitBreaker("test-endpoint:6379")
// Force circuit to open
for i := 0; i < 5; i++ {
cb.Execute(func() error { return testError })
}
if cb.GetState() != CircuitBreakerOpen {
t.Error("Circuit should be open")
}
// Reset all circuit breakers
manager.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Error("Circuit should be closed after reset")
}
if cb.failures.Load() != 0 {
t.Error("Failure count should be reset to 0")
}
})
}

View File

@@ -48,3 +48,8 @@ var (
var (
ErrShutdown = errors.New("hitless: shutdown")
)
// circuit breaker errors
var (
ErrCircuitBreakerOpen = errors.New("hitless: circuit breaker is open, failing fast")
)

View File

@@ -2,6 +2,7 @@ package hitless
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
@@ -78,3 +79,22 @@ func (mh *MetricsHook) GetMetrics() map[string]interface{} {
"error_counts": mh.ErrorCounts,
}
}
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
// Get circuit breaker statistics
stats := poolHook.GetCircuitBreakerStats()
for _, stat := range stats {
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
fmt.Printf(" State: %s\n", stat.State)
fmt.Printf(" Failures: %d\n", stat.Failures)
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
// Alert if circuit breaker is open
if stat.State.String() == "open" {
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
}
}
}

View File

@@ -33,18 +33,22 @@ type handoffWorkerManager struct {
// Pool hook reference for handoff processing
poolHook *PoolHook
// Circuit breaker manager for endpoint failure handling
circuitBreakerManager *CircuitBreakerManager
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
circuitBreakerManager: newCircuitBreakerManager(config),
}
}
@@ -68,6 +72,16 @@ func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
return hwm.circuitBreakerManager.GetAllStats()
}
// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
hwm.circuitBreakerManager.Reset()
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
@@ -286,8 +300,37 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
return false, ErrConnectionInvalidHandoffState
}
// Use circuit breaker to protect against failing endpoints
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, "hitless: conn[%d] handoff to %s failed fast due to circuit breaker", connID, newEndpoint)
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
// Perform the handoff
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
// Update circuit breaker based on result
if err != nil {
// Only track dial/network errors in circuit breaker, not initialization errors
if shouldRetry {
circuitBreaker.recordFailure()
}
return shouldRetry, err
}
// Success - record in circuit breaker
circuitBreaker.recordSuccess()
return false, nil
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String())
internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", connID, retries, newEndpoint, conn.RemoteAddr().String())
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
@@ -297,7 +340,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
if hwm.config != nil && hwm.config.LogLevel.WarnOrAbove() { // Warning level
internal.Logger.Printf(ctx,
"hitless: reached max retries (%d) for handoff of conn[%d] to %s",
maxRetries, conn.GetID(), conn.GetHandoffEndpoint())
maxRetries, connID, newEndpoint)
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
@@ -309,7 +352,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err)
internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", connID, newEndpoint, err)
// hitless: will retry
// Maybe a network error - retry after a delay
return true, err
@@ -349,7 +392,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c
}()
conn.ClearHandoffState()
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint)
internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", connID, newEndpoint)
return false, nil
}

View File

@@ -105,6 +105,16 @@ func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
return ph.workerManager.getCircuitBreakerStats()
}
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
ph.workerManager.resetCircuitBreakers()
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error {
// NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is