From 1ff0ded0e33222104d91287f469f6ffbd15db1d9 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 26 Jun 2025 20:38:30 +0300 Subject: [PATCH] feat: enforce single handler per notification type - Change PushNotificationRegistry to allow only one handler per command - RegisterHandler methods now return error if handler already exists - Update UnregisterHandler to remove handler by command only - Update all client methods to return errors for duplicate registrations - Update comprehensive test suite to verify single handler behavior - Add specific test for duplicate handler error scenarios This prevents handler conflicts and ensures predictable notification routing with clear error handling for registration conflicts. --- push_notifications.go | 50 +++++----- push_notifications_test.go | 194 ++++++++++++++++++++++--------------- redis.go | 12 ++- 3 files changed, 146 insertions(+), 110 deletions(-) diff --git a/push_notifications.go b/push_notifications.go index 70741116..cc1bae90 100644 --- a/push_notifications.go +++ b/push_notifications.go @@ -2,6 +2,7 @@ package redis import ( "context" + "fmt" "sync" "github.com/redis/go-redis/v9/internal" @@ -26,27 +27,29 @@ func (f PushNotificationHandlerFunc) HandlePushNotification(ctx context.Context, // PushNotificationRegistry manages handlers for different types of push notifications. type PushNotificationRegistry struct { mu sync.RWMutex - handlers map[string][]PushNotificationHandler // command -> handlers - global []PushNotificationHandler // global handlers for all notifications + handlers map[string]PushNotificationHandler // command -> single handler + global []PushNotificationHandler // global handlers for all notifications } // NewPushNotificationRegistry creates a new push notification registry. func NewPushNotificationRegistry() *PushNotificationRegistry { return &PushNotificationRegistry{ - handlers: make(map[string][]PushNotificationHandler), + handlers: make(map[string]PushNotificationHandler), global: make([]PushNotificationHandler, 0), } } // RegisterHandler registers a handler for a specific push notification command. -func (r *PushNotificationRegistry) RegisterHandler(command string, handler PushNotificationHandler) { +// Returns an error if a handler is already registered for this command. +func (r *PushNotificationRegistry) RegisterHandler(command string, handler PushNotificationHandler) error { r.mu.Lock() defer r.mu.Unlock() - if r.handlers[command] == nil { - r.handlers[command] = make([]PushNotificationHandler, 0) + if _, exists := r.handlers[command]; exists { + return fmt.Errorf("handler already registered for command: %s", command) } - r.handlers[command] = append(r.handlers[command], handler) + r.handlers[command] = handler + return nil } // RegisterGlobalHandler registers a handler that will receive all push notifications. @@ -57,19 +60,12 @@ func (r *PushNotificationRegistry) RegisterGlobalHandler(handler PushNotificatio r.global = append(r.global, handler) } -// UnregisterHandler removes a handler for a specific command. -func (r *PushNotificationRegistry) UnregisterHandler(command string, handler PushNotificationHandler) { +// UnregisterHandler removes the handler for a specific push notification command. +func (r *PushNotificationRegistry) UnregisterHandler(command string) { r.mu.Lock() defer r.mu.Unlock() - handlers := r.handlers[command] - for i, h := range handlers { - // Compare function pointers (this is a simplified approach) - if &h == &handler { - r.handlers[command] = append(handlers[:i], handlers[i+1:]...) - break - } - } + delete(r.handlers, command) } // HandleNotification processes a push notification by calling all registered handlers. @@ -96,12 +92,10 @@ func (r *PushNotificationRegistry) HandleNotification(ctx context.Context, notif } } - // Call specific handlers - if handlers, exists := r.handlers[command]; exists { - for _, handler := range handlers { - if handler.HandlePushNotification(ctx, notification) { - handled = true - } + // Call specific handler + if handler, exists := r.handlers[command]; exists { + if handler.HandlePushNotification(ctx, notification) { + handled = true } } @@ -207,8 +201,9 @@ func (p *PushNotificationProcessor) ProcessPendingNotifications(ctx context.Cont } // RegisterHandler is a convenience method to register a handler for a specific command. -func (p *PushNotificationProcessor) RegisterHandler(command string, handler PushNotificationHandler) { - p.registry.RegisterHandler(command, handler) +// Returns an error if a handler is already registered for this command. +func (p *PushNotificationProcessor) RegisterHandler(command string, handler PushNotificationHandler) error { + return p.registry.RegisterHandler(command, handler) } // RegisterGlobalHandler is a convenience method to register a global handler. @@ -217,8 +212,9 @@ func (p *PushNotificationProcessor) RegisterGlobalHandler(handler PushNotificati } // RegisterHandlerFunc is a convenience method to register a function as a handler. -func (p *PushNotificationProcessor) RegisterHandlerFunc(command string, handlerFunc func(ctx context.Context, notification []interface{}) bool) { - p.registry.RegisterHandler(command, PushNotificationHandlerFunc(handlerFunc)) +// Returns an error if a handler is already registered for this command. +func (p *PushNotificationProcessor) RegisterHandlerFunc(command string, handlerFunc func(ctx context.Context, notification []interface{}) bool) error { + return p.registry.RegisterHandler(command, PushNotificationHandlerFunc(handlerFunc)) } // RegisterGlobalHandlerFunc is a convenience method to register a function as a global handler. diff --git a/push_notifications_test.go b/push_notifications_test.go index 42e29874..2f868584 100644 --- a/push_notifications_test.go +++ b/push_notifications_test.go @@ -29,7 +29,10 @@ func TestPushNotificationRegistry(t *testing.T) { return true }) - registry.RegisterHandler("TEST_COMMAND", handler) + err := registry.RegisterHandler("TEST_COMMAND", handler) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } if !registry.HasHandlers() { t.Error("Registry should have handlers after registration") @@ -80,6 +83,19 @@ func TestPushNotificationRegistry(t *testing.T) { if !globalHandlerCalled { t.Error("Global handler should have been called") } + + // Test duplicate handler registration error + duplicateHandler := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { + return true + }) + err = registry.RegisterHandler("TEST_COMMAND", duplicateHandler) + if err == nil { + t.Error("Expected error when registering duplicate handler") + } + expectedError := "handler already registered for command: TEST_COMMAND" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error()) + } } func TestPushNotificationProcessor(t *testing.T) { @@ -92,7 +108,7 @@ func TestPushNotificationProcessor(t *testing.T) { // Test registering handlers handlerCalled := false - processor.RegisterHandlerFunc("CUSTOM_NOTIFICATION", func(ctx context.Context, notification []interface{}) bool { + err := processor.RegisterHandlerFunc("CUSTOM_NOTIFICATION", func(ctx context.Context, notification []interface{}) bool { handlerCalled = true if len(notification) < 2 { t.Error("Expected at least 2 elements in notification") @@ -104,6 +120,9 @@ func TestPushNotificationProcessor(t *testing.T) { } return true }) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } // Test global handler globalHandlerCalled := false @@ -157,10 +176,13 @@ func TestClientPushNotificationIntegration(t *testing.T) { // Test registering handlers through client handlerCalled := false - client.RegisterPushNotificationHandlerFunc("CUSTOM_EVENT", func(ctx context.Context, notification []interface{}) bool { + err := client.RegisterPushNotificationHandlerFunc("CUSTOM_EVENT", func(ctx context.Context, notification []interface{}) bool { handlerCalled = true return true }) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } // Test global handler through client globalHandlerCalled := false @@ -232,10 +254,13 @@ func TestPushNotificationEnabledClient(t *testing.T) { // Test registering a handler handlerCalled := false - client.RegisterPushNotificationHandlerFunc("TEST_NOTIFICATION", func(ctx context.Context, notification []interface{}) bool { + err := client.RegisterPushNotificationHandlerFunc("TEST_NOTIFICATION", func(ctx context.Context, notification []interface{}) bool { handlerCalled = true return true }) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } // Test that the handler works registry := processor.GetRegistry() @@ -318,11 +343,14 @@ func TestPubSubWithGenericPushNotifications(t *testing.T) { // Register a handler for custom push notifications customNotificationReceived := false - client.RegisterPushNotificationHandlerFunc("CUSTOM_PUBSUB_EVENT", func(ctx context.Context, notification []interface{}) bool { + err := client.RegisterPushNotificationHandlerFunc("CUSTOM_PUBSUB_EVENT", func(ctx context.Context, notification []interface{}) bool { customNotificationReceived = true t.Logf("Received custom push notification in PubSub context: %v", notification) return true }) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } // Create a PubSub instance pubsub := client.Subscribe(context.Background(), "test-channel") @@ -370,32 +398,28 @@ func TestPushNotificationMessageType(t *testing.T) { } func TestPushNotificationRegistryUnregisterHandler(t *testing.T) { - // Test unregistering handlers (note: current implementation has limitations with function pointer comparison) + // Test unregistering handlers registry := redis.NewPushNotificationRegistry() - // Register multiple handlers for the same command - handler1Called := false - handler1 := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - handler1Called = true + // Register a handler + handlerCalled := false + handler := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { + handlerCalled = true return true }) - handler2Called := false - handler2 := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - handler2Called = true - return true - }) + err := registry.RegisterHandler("TEST_CMD", handler) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } - registry.RegisterHandler("TEST_CMD", handler1) - registry.RegisterHandler("TEST_CMD", handler2) - - // Verify both handlers are registered + // Verify handler is registered commands := registry.GetRegisteredCommands() if len(commands) != 1 || commands[0] != "TEST_CMD" { t.Errorf("Expected ['TEST_CMD'], got %v", commands) } - // Test notification handling with both handlers + // Test notification handling ctx := context.Background() notification := []interface{}{"TEST_CMD", "data"} handled := registry.HandleNotification(ctx, notification) @@ -403,31 +427,32 @@ func TestPushNotificationRegistryUnregisterHandler(t *testing.T) { if !handled { t.Error("Notification should have been handled") } - if !handler1Called || !handler2Called { - t.Error("Both handlers should have been called") + if !handlerCalled { + t.Error("Handler should have been called") } - // Test that UnregisterHandler doesn't panic (even if it doesn't work perfectly) - registry.UnregisterHandler("TEST_CMD", handler1) - registry.UnregisterHandler("NON_EXISTENT", handler2) + // Test unregistering the handler + registry.UnregisterHandler("TEST_CMD") - // Note: Due to the current implementation using pointer comparison, - // unregistration may not work as expected. This test mainly verifies - // that the method doesn't panic and the registry remains functional. - - // Reset flags and test that handlers still work - handler1Called = false - handler2Called = false + // Verify handler is unregistered + commands = registry.GetRegisteredCommands() + if len(commands) != 0 { + t.Errorf("Expected no registered commands after unregister, got %v", commands) + } + // Reset flag and test that handler is no longer called + handlerCalled = false handled = registry.HandleNotification(ctx, notification) - if !handled { - t.Error("Notification should still be handled after unregister attempts") + + if handled { + t.Error("Notification should not be handled after unregistration") + } + if handlerCalled { + t.Error("Handler should not be called after unregistration") } - // The registry should still be functional - if !registry.HasHandlers() { - t.Error("Registry should still have handlers") - } + // Test unregistering non-existent handler (should not panic) + registry.UnregisterHandler("NON_EXISTENT") } func TestPushNotificationRegistryEdgeCases(t *testing.T) { @@ -453,51 +478,47 @@ func TestPushNotificationRegistryEdgeCases(t *testing.T) { } // Test unregistering non-existent handler - dummyHandler := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - return true - }) - registry.UnregisterHandler("NON_EXISTENT", dummyHandler) + registry.UnregisterHandler("NON_EXISTENT") // Should not panic // Test unregistering from empty command - registry.UnregisterHandler("EMPTY_CMD", dummyHandler) + registry.UnregisterHandler("EMPTY_CMD") // Should not panic } -func TestPushNotificationRegistryMultipleHandlers(t *testing.T) { +func TestPushNotificationRegistryDuplicateHandlerError(t *testing.T) { registry := redis.NewPushNotificationRegistry() - // Test multiple handlers for the same command - handler1Called := false - handler2Called := false - handler3Called := false - - registry.RegisterHandler("MULTI_CMD", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - handler1Called = true + // Test that registering duplicate handlers returns an error + handler1 := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { return true - })) + }) - registry.RegisterHandler("MULTI_CMD", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - handler2Called = true - return false // Return false to test that other handlers still get called - })) + handler2 := redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { + return false + }) - registry.RegisterHandler("MULTI_CMD", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { - handler3Called = true - return true - })) - - // Test that all handlers are called - ctx := context.Background() - notification := []interface{}{"MULTI_CMD", "data"} - handled := registry.HandleNotification(ctx, notification) - - if !handled { - t.Error("Notification should be handled (at least one handler returned true)") + // Register first handler - should succeed + err := registry.RegisterHandler("DUPLICATE_CMD", handler1) + if err != nil { + t.Fatalf("First handler registration should succeed: %v", err) } - if !handler1Called || !handler2Called || !handler3Called { - t.Error("All handlers should have been called") + // Register second handler for same command - should fail + err = registry.RegisterHandler("DUPLICATE_CMD", handler2) + if err == nil { + t.Error("Second handler registration should fail") + } + + expectedError := "handler already registered for command: DUPLICATE_CMD" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error()) + } + + // Verify only one handler is registered + commands := registry.GetRegisteredCommands() + if len(commands) != 1 || commands[0] != "DUPLICATE_CMD" { + t.Errorf("Expected ['DUPLICATE_CMD'], got %v", commands) } } @@ -514,10 +535,13 @@ func TestPushNotificationRegistryGlobalAndSpecific(t *testing.T) { })) // Register specific handler - registry.RegisterHandler("SPECIFIC_CMD", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { + err := registry.RegisterHandler("SPECIFIC_CMD", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { specificCalled = true return true })) + if err != nil { + t.Fatalf("Failed to register specific handler: %v", err) + } // Test with specific command ctx := context.Background() @@ -602,7 +626,10 @@ func TestPushNotificationProcessorConvenienceMethods(t *testing.T) { return true }) - processor.RegisterHandler("CONV_CMD", handler) + err := processor.RegisterHandler("CONV_CMD", handler) + if err != nil { + t.Fatalf("Failed to register handler: %v", err) + } // Test RegisterGlobalHandler convenience method globalHandlerCalled := false @@ -615,10 +642,13 @@ func TestPushNotificationProcessorConvenienceMethods(t *testing.T) { // Test RegisterHandlerFunc convenience method funcHandlerCalled := false - processor.RegisterHandlerFunc("FUNC_CMD", func(ctx context.Context, notification []interface{}) bool { + err = processor.RegisterHandlerFunc("FUNC_CMD", func(ctx context.Context, notification []interface{}) bool { funcHandlerCalled = true return true }) + if err != nil { + t.Fatalf("Failed to register func handler: %v", err) + } // Test RegisterGlobalHandlerFunc convenience method globalFuncHandlerCalled := false @@ -669,18 +699,24 @@ func TestClientPushNotificationEdgeCases(t *testing.T) { }) defer client.Close() - // These should not panic even when processor is nil - client.RegisterPushNotificationHandler("TEST", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { + // These should not panic even when processor is nil and should return nil error + err := client.RegisterPushNotificationHandler("TEST", redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { return true })) + if err != nil { + t.Errorf("Expected nil error when processor is nil, got: %v", err) + } client.RegisterGlobalPushNotificationHandler(redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { return true })) - client.RegisterPushNotificationHandlerFunc("TEST_FUNC", func(ctx context.Context, notification []interface{}) bool { + err = client.RegisterPushNotificationHandlerFunc("TEST_FUNC", func(ctx context.Context, notification []interface{}) bool { return true }) + if err != nil { + t.Errorf("Expected nil error when processor is nil, got: %v", err) + } client.RegisterGlobalPushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { return true @@ -821,7 +857,7 @@ func TestPushNotificationRegistryConcurrency(t *testing.T) { defer func() { done <- true }() for j := 0; j < numOperations; j++ { - // Register handler + // Register handler (ignore errors in concurrency test) command := fmt.Sprintf("CMD_%d_%d", id, j) registry.RegisterHandler(command, redis.PushNotificationHandlerFunc(func(ctx context.Context, notification []interface{}) bool { return true @@ -876,7 +912,7 @@ func TestPushNotificationProcessorConcurrency(t *testing.T) { defer func() { done <- true }() for j := 0; j < numOperations; j++ { - // Register handlers + // Register handlers (ignore errors in concurrency test) command := fmt.Sprintf("PROC_CMD_%d_%d", id, j) processor.RegisterHandlerFunc(command, func(ctx context.Context, notification []interface{}) bool { return true @@ -930,7 +966,7 @@ func TestPushNotificationClientConcurrency(t *testing.T) { defer func() { done <- true }() for j := 0; j < numOperations; j++ { - // Register handlers concurrently + // Register handlers concurrently (ignore errors in concurrency test) command := fmt.Sprintf("CLIENT_CMD_%d_%d", id, j) client.RegisterPushNotificationHandlerFunc(command, func(ctx context.Context, notification []interface{}) bool { return true diff --git a/redis.go b/redis.go index 19167615..c7a6701e 100644 --- a/redis.go +++ b/redis.go @@ -814,10 +814,12 @@ func (c *Client) initializePushProcessor() { } // RegisterPushNotificationHandler registers a handler for a specific push notification command. -func (c *Client) RegisterPushNotificationHandler(command string, handler PushNotificationHandler) { +// Returns an error if a handler is already registered for this command. +func (c *Client) RegisterPushNotificationHandler(command string, handler PushNotificationHandler) error { if c.pushProcessor != nil { - c.pushProcessor.RegisterHandler(command, handler) + return c.pushProcessor.RegisterHandler(command, handler) } + return nil } // RegisterGlobalPushNotificationHandler registers a handler that will receive all push notifications. @@ -828,10 +830,12 @@ func (c *Client) RegisterGlobalPushNotificationHandler(handler PushNotificationH } // RegisterPushNotificationHandlerFunc registers a function as a handler for a specific push notification command. -func (c *Client) RegisterPushNotificationHandlerFunc(command string, handlerFunc func(ctx context.Context, notification []interface{}) bool) { +// Returns an error if a handler is already registered for this command. +func (c *Client) RegisterPushNotificationHandlerFunc(command string, handlerFunc func(ctx context.Context, notification []interface{}) bool) error { if c.pushProcessor != nil { - c.pushProcessor.RegisterHandlerFunc(command, handlerFunc) + return c.pushProcessor.RegisterHandlerFunc(command, handlerFunc) } + return nil } // RegisterGlobalPushNotificationHandlerFunc registers a function as a global handler for all push notifications.