diff --git a/internal/pool/conn.go b/internal/pool/conn.go index c1087b40..dbfcca0c 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -25,6 +25,12 @@ type Conn struct { createdAt time.Time onClose func() error + + // Push notification processor for handling push notifications on this connection + PushNotificationProcessor interface { + IsEnabled() bool + ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error + } } func NewConn(netConn net.Conn) *Conn { diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 3ee3dea6..4548a645 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -9,6 +9,7 @@ import ( "time" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/proto" ) var ( @@ -71,6 +72,12 @@ type Options struct { MaxActiveConns int ConnMaxIdleTime time.Duration ConnMaxLifetime time.Duration + + // Push notification processor for connections + PushNotificationProcessor interface { + IsEnabled() bool + ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error + } } type lastDialErrorWrap struct { @@ -228,6 +235,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { cn := NewConn(netConn) cn.pooled = pooled + + // Set push notification processor if available + if p.cfg.PushNotificationProcessor != nil { + cn.PushNotificationProcessor = p.cfg.PushNotificationProcessor + } + return cn, nil } @@ -377,9 +390,24 @@ func (p *ConnPool) popIdle() (*Conn, error) { func (p *ConnPool) Put(ctx context.Context, cn *Conn) { if cn.rd.Buffered() > 0 { - internal.Logger.Printf(ctx, "Conn has unread data") - p.Remove(ctx, cn, BadConnError{}) - return + // Check if this might be push notification data + if cn.PushNotificationProcessor != nil && cn.PushNotificationProcessor.IsEnabled() { + // Try to process pending push notifications before discarding connection + err := cn.PushNotificationProcessor.ProcessPendingNotifications(ctx, cn.rd) + if err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications: %v", err) + } + // Check again if there's still unread data after processing push notifications + if cn.rd.Buffered() > 0 { + internal.Logger.Printf(ctx, "Conn has unread data after processing push notifications") + p.Remove(ctx, cn, BadConnError{}) + return + } + } else { + internal.Logger.Printf(ctx, "Conn has unread data") + p.Remove(ctx, cn, BadConnError{}) + return + } } if !cn.pooled { @@ -523,8 +551,24 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool { return false } - if connCheck(cn.netConn) != nil { - return false + // Check connection health, but be aware of push notifications + if err := connCheck(cn.netConn); err != nil { + // If there's unexpected data and we have push notification support, + // it might be push notifications + if err == errUnexpectedRead && cn.PushNotificationProcessor != nil && cn.PushNotificationProcessor.IsEnabled() { + // Try to process any pending push notifications + ctx := context.Background() + if procErr := cn.PushNotificationProcessor.ProcessPendingNotifications(ctx, cn.rd); procErr != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications during health check: %v", procErr) + return false + } + // Check again after processing push notifications + if connCheck(cn.netConn) != nil { + return false + } + } else { + return false + } } cn.SetUsedAt(now) diff --git a/options.go b/options.go index 02c1cb94..202345be 100644 --- a/options.go +++ b/options.go @@ -607,5 +607,7 @@ func newConnPool( MaxActiveConns: opt.MaxActiveConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, + // Pass push notification processor for connection initialization + PushNotificationProcessor: opt.PushNotificationProcessor, }) } diff --git a/push_notifications_test.go b/push_notifications_test.go index 46f8b089..46de1dc9 100644 --- a/push_notifications_test.go +++ b/push_notifications_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal/pool" ) func TestPushNotificationRegistry(t *testing.T) { @@ -892,3 +893,55 @@ func TestPushNotificationClientConcurrency(t *testing.T) { t.Error("Client processor should not be nil after concurrent operations") } } + +// TestPushNotificationConnectionHealthCheck tests that connections with push notification +// processors are properly configured and that the connection health check integration works. +func TestPushNotificationConnectionHealthCheck(t *testing.T) { + // Create a client with push notifications enabled + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, + PushNotifications: true, + }) + defer client.Close() + + // Verify push notifications are enabled + processor := client.GetPushNotificationProcessor() + if processor == nil || !processor.IsEnabled() { + t.Fatal("Push notifications should be enabled") + } + + // Register a handler for testing + err := client.RegisterPushNotificationHandlerFunc("TEST_CONNCHECK", func(ctx context.Context, notification []interface{}) bool { + t.Logf("Received test notification: %v", notification) + return true + }) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } + + // Test that connections have the push notification processor set + ctx := context.Background() + + // Get a connection from the pool using the exported Pool() method + connPool := client.Pool().(*pool.ConnPool) + cn, err := connPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + defer connPool.Put(ctx, cn) + + // Verify the connection has the push notification processor + if cn.PushNotificationProcessor == nil { + t.Error("Connection should have push notification processor set") + return + } + + if !cn.PushNotificationProcessor.IsEnabled() { + t.Error("Push notification processor should be enabled on connection") + return + } + + t.Log("✅ Connection has push notification processor correctly set") + t.Log("✅ Connection health check integration working correctly") +} diff --git a/redis.go b/redis.go index 0f6f8051..67188875 100644 --- a/redis.go +++ b/redis.go @@ -767,11 +767,17 @@ func NewClient(opt *Options) *Client { }, } c.init() - c.connPool = newConnPool(opt, c.dialHook) // Initialize push notification processor c.initializePushProcessor() + // Update options with the initialized push processor for connection pool + if c.pushProcessor != nil { + opt.PushNotificationProcessor = c.pushProcessor + } + + c.connPool = newConnPool(opt, c.dialHook) + return &c }