diff --git a/push_notifications.go b/push_notifications.go index 5dc44946..6777df00 100644 --- a/push_notifications.go +++ b/push_notifications.go @@ -96,17 +96,23 @@ func (r *PushNotificationRegistry) GetRegisteredPushNotificationNames() []string return names } -// HasHandlers returns true if there are any handlers registered. -func (r *PushNotificationRegistry) HasHandlers() bool { +// GetHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (r *PushNotificationRegistry) GetHandler(pushNotificationName string) PushNotificationHandler { r.mu.RLock() defer r.mu.RUnlock() - return len(r.handlers) > 0 + handler, exists := r.handlers[pushNotificationName] + if !exists { + return nil + } + return handler } // PushNotificationProcessorInterface defines the interface for push notification processors. type PushNotificationProcessorInterface interface { - GetRegistry() *PushNotificationRegistry + GetHandler(pushNotificationName string) PushNotificationHandler + GetRegistry() *PushNotificationRegistry // For backward compatibility and testing ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error RegisterHandler(pushNotificationName string, handler PushNotificationHandler, protected bool) error } @@ -123,16 +129,20 @@ func NewPushNotificationProcessor() *PushNotificationProcessor { } } -// GetRegistry returns the push notification registry. +// GetHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (p *PushNotificationProcessor) GetHandler(pushNotificationName string) PushNotificationHandler { + return p.registry.GetHandler(pushNotificationName) +} + +// GetRegistry returns the push notification registry for internal use. +// This method is primarily for testing and internal operations. func (p *PushNotificationProcessor) GetRegistry() *PushNotificationRegistry { return p.registry } // ProcessPendingNotifications checks for and processes any pending push notifications. func (p *PushNotificationProcessor) ProcessPendingNotifications(ctx context.Context, rd *proto.Reader) error { - if !p.registry.HasHandlers() { - return nil - } // Check if there are any buffered bytes that might contain push notifications if rd.Buffered() == 0 { @@ -233,6 +243,11 @@ func NewVoidPushNotificationProcessor() *VoidPushNotificationProcessor { return &VoidPushNotificationProcessor{} } +// GetHandler returns nil for void processor since it doesn't maintain handlers. +func (v *VoidPushNotificationProcessor) GetHandler(pushNotificationName string) PushNotificationHandler { + return nil +} + // GetRegistry returns nil for void processor since it doesn't maintain handlers. func (v *VoidPushNotificationProcessor) GetRegistry() *PushNotificationRegistry { return nil diff --git a/push_notifications_test.go b/push_notifications_test.go index 87ef8265..492c2734 100644 --- a/push_notifications_test.go +++ b/push_notifications_test.go @@ -28,9 +28,7 @@ func TestPushNotificationRegistry(t *testing.T) { registry := redis.NewPushNotificationRegistry() // Test initial state - if registry.HasHandlers() { - t.Error("Registry should not have handlers initially") - } + // Registry starts empty (no need to check HasHandlers anymore) commands := registry.GetRegisteredPushNotificationNames() if len(commands) != 0 { @@ -49,10 +47,7 @@ func TestPushNotificationRegistry(t *testing.T) { t.Fatalf("Failed to register handler: %v", err) } - if !registry.HasHandlers() { - t.Error("Registry should have handlers after registration") - } - + // Verify handler was registered by checking registered names commands = registry.GetRegisteredPushNotificationNames() if len(commands) != 1 || commands[0] != "TEST_COMMAND" { t.Errorf("Expected ['TEST_COMMAND'], got %v", commands) @@ -803,7 +798,6 @@ func TestPushNotificationRegistryConcurrency(t *testing.T) { registry.HandleNotification(context.Background(), notification) // Check registry state - registry.HasHandlers() registry.GetRegisteredPushNotificationNames() } }(i) @@ -815,10 +809,6 @@ func TestPushNotificationRegistryConcurrency(t *testing.T) { } // Verify registry is still functional - if !registry.HasHandlers() { - t.Error("Registry should have handlers after concurrent operations") - } - commands := registry.GetRegisteredPushNotificationNames() if len(commands) == 0 { t.Error("Registry should have registered commands after concurrent operations") diff --git a/redis.go b/redis.go index 5946e1ae..cd015daf 100644 --- a/redis.go +++ b/redis.go @@ -837,6 +837,12 @@ func (c *Client) GetPushNotificationProcessor() PushNotificationProcessorInterfa return c.pushProcessor } +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *Client) GetPushNotificationHandler(pushNotificationName string) PushNotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + type PoolStats pool.Stats // PoolStats returns connection pool stats. diff --git a/sentinel.go b/sentinel.go index df5742a3..948f3c97 100644 --- a/sentinel.go +++ b/sentinel.go @@ -522,6 +522,12 @@ func (c *SentinelClient) GetPushNotificationProcessor() PushNotificationProcesso return c.pushProcessor } +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) PushNotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + // RegisterPushNotificationHandler registers a handler for a specific push notification name. // Returns an error if a handler is already registered for this push notification name. // If protected is true, the handler cannot be unregistered.