diff --git a/client_side_cache.go b/client_side_cache.go new file mode 100644 index 00000000..3cc99340 --- /dev/null +++ b/client_side_cache.go @@ -0,0 +1,354 @@ +package redis + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +// ClientSideCache represents a client-side cache with Redis invalidation support. +// It provides automatic cache invalidation through Redis CLIENT TRACKING and push notifications. +type ClientSideCache struct { + // Local cache storage + cache map[string]*cacheEntry + mu sync.RWMutex + + // Cache configuration + maxSize int + defaultTTL time.Duration + + // Redis client for operations and tracking + client *Client + + // Invalidation processing + invalidations chan []string + stopCh chan struct{} + wg sync.WaitGroup + + // Cache statistics + hits int64 + misses int64 + evictions int64 +} + +// cacheEntry represents a cached value with metadata +type cacheEntry struct { + Value interface{} + ExpiresAt time.Time + Key string + CreatedAt time.Time +} + +// ClientSideCacheOptions configures the client-side cache +type ClientSideCacheOptions struct { + // Cache size and TTL settings + MaxSize int // Maximum number of entries (default: 10000) + DefaultTTL time.Duration // Default TTL for cached entries (default: 5 minutes) + + // Redis client tracking options + EnableTracking bool // Enable Redis CLIENT TRACKING (default: true) + TrackingPrefix []string // Only track keys with these prefixes (optional) + NoLoop bool // Don't track keys modified by this client (default: true) + + // Cache behavior + InvalidationBufferSize int // Buffer size for invalidation channel (default: 1000) +} + +// NewClientSideCache creates a new client-side cache with Redis invalidation support. +// It automatically enables Redis CLIENT TRACKING and registers an invalidation handler. +func NewClientSideCache(client *Client, opts *ClientSideCacheOptions) (*ClientSideCache, error) { + if opts == nil { + opts = &ClientSideCacheOptions{ + MaxSize: 10000, + DefaultTTL: 5 * time.Minute, + EnableTracking: true, + NoLoop: true, + InvalidationBufferSize: 1000, + } + } + + // Set defaults for zero values + if opts.MaxSize <= 0 { + opts.MaxSize = 10000 + } + if opts.DefaultTTL <= 0 { + opts.DefaultTTL = 5 * time.Minute + } + if opts.InvalidationBufferSize <= 0 { + opts.InvalidationBufferSize = 1000 + } + + csc := &ClientSideCache{ + cache: make(map[string]*cacheEntry), + maxSize: opts.MaxSize, + defaultTTL: opts.DefaultTTL, + client: client, + invalidations: make(chan []string, opts.InvalidationBufferSize), + stopCh: make(chan struct{}), + } + + // Enable Redis client tracking + if opts.EnableTracking { + if err := csc.enableClientTracking(opts); err != nil { + return nil, err + } + } + + // Register invalidation handler + handler := &clientSideCacheInvalidationHandler{cache: csc} + if err := client.RegisterPushNotificationHandler("invalidate", handler, true); err != nil { + return nil, err + } + + // Start invalidation processor + csc.wg.Add(1) + go csc.processInvalidations() + + return csc, nil +} + +// enableClientTracking enables Redis CLIENT TRACKING for cache invalidation +func (csc *ClientSideCache) enableClientTracking(opts *ClientSideCacheOptions) error { + ctx := context.Background() + + // Build CLIENT TRACKING command + args := []interface{}{"CLIENT", "TRACKING", "ON"} + + if opts.NoLoop { + args = append(args, "NOLOOP") + } + + // If prefixes are specified, we need to use BCAST mode + if len(opts.TrackingPrefix) > 0 { + args = append(args, "BCAST") + for _, prefix := range opts.TrackingPrefix { + args = append(args, "PREFIX", prefix) + } + } + + // Enable tracking + cmd := csc.client.Do(ctx, args...) + return cmd.Err() +} + +// Get retrieves a value from the cache, falling back to Redis if not found. +// If the key is found in the local cache and not expired, it returns immediately. +// Otherwise, it fetches from Redis and stores the result in the local cache. +func (csc *ClientSideCache) Get(ctx context.Context, key string) *StringCmd { + // Try local cache first + if value, found := csc.getFromCache(key); found { + // Create a successful StringCmd with the cached value + cmd := NewStringCmd(ctx, "get", key) + cmd.SetVal(value.(string)) + return cmd + } + + // Cache miss - get from Redis + cmd := csc.client.Get(ctx, key) + if cmd.Err() == nil { + // Store successful result in local cache + csc.setInCache(key, cmd.Val(), csc.defaultTTL) + } + + return cmd +} + +// Set stores a value in Redis and updates the local cache. +// The value is first stored in Redis, and if successful, also cached locally. +func (csc *ClientSideCache) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd { + // Set in Redis first + cmd := csc.client.Set(ctx, key, value, expiration) + if cmd.Err() == nil { + // Update local cache on success + ttl := expiration + if ttl <= 0 { + ttl = csc.defaultTTL + } + csc.setInCache(key, value, ttl) + } + + return cmd +} + +// Del deletes keys from Redis and removes them from the local cache. +func (csc *ClientSideCache) Del(ctx context.Context, keys ...string) *IntCmd { + // Delete from Redis first + cmd := csc.client.Del(ctx, keys...) + if cmd.Err() == nil { + // Remove from local cache on success + csc.invalidateKeys(keys) + } + + return cmd +} + +// getFromCache retrieves a value from the local cache +func (csc *ClientSideCache) getFromCache(key string) (interface{}, bool) { + csc.mu.RLock() + defer csc.mu.RUnlock() + + entry, exists := csc.cache[key] + if !exists { + atomic.AddInt64(&csc.misses, 1) + return nil, false + } + + // Check expiration + if time.Now().After(entry.ExpiresAt) { + // Entry expired - remove it + delete(csc.cache, key) + atomic.AddInt64(&csc.misses, 1) + return nil, false + } + + atomic.AddInt64(&csc.hits, 1) + return entry.Value, true +} + +// setInCache stores a value in the local cache +func (csc *ClientSideCache) setInCache(key string, value interface{}, ttl time.Duration) { + csc.mu.Lock() + defer csc.mu.Unlock() + + // Check cache size limit + if len(csc.cache) >= csc.maxSize { + // Simple LRU eviction - remove oldest entry + csc.evictOldest() + } + + // Store entry + now := time.Now() + csc.cache[key] = &cacheEntry{ + Value: value, + ExpiresAt: now.Add(ttl), + Key: key, + CreatedAt: now, + } +} + +// evictOldest removes the oldest cache entry (simple LRU based on creation time) +func (csc *ClientSideCache) evictOldest() { + var oldestKey string + var oldestTime time.Time + + for key, entry := range csc.cache { + if oldestKey == "" || entry.CreatedAt.Before(oldestTime) { + oldestKey = key + oldestTime = entry.CreatedAt + } + } + + if oldestKey != "" { + delete(csc.cache, oldestKey) + atomic.AddInt64(&csc.evictions, 1) + } +} + +// processInvalidations processes cache invalidation notifications from Redis +func (csc *ClientSideCache) processInvalidations() { + defer csc.wg.Done() + + for { + select { + case keys := <-csc.invalidations: + csc.invalidateKeys(keys) + case <-csc.stopCh: + return + } + } +} + +// invalidateKeys removes specified keys from the local cache +func (csc *ClientSideCache) invalidateKeys(keys []string) { + if len(keys) == 0 { + return + } + + csc.mu.Lock() + defer csc.mu.Unlock() + + for _, key := range keys { + delete(csc.cache, key) + } +} + +// GetStats returns cache statistics +func (csc *ClientSideCache) GetStats() (hits, misses, evictions int64, hitRatio float64, size int) { + csc.mu.RLock() + size = len(csc.cache) + csc.mu.RUnlock() + + hits = atomic.LoadInt64(&csc.hits) + misses = atomic.LoadInt64(&csc.misses) + evictions = atomic.LoadInt64(&csc.evictions) + + total := hits + misses + if total > 0 { + hitRatio = float64(hits) / float64(total) + } + + return hits, misses, evictions, hitRatio, size +} + +// Clear removes all entries from the local cache +func (csc *ClientSideCache) Clear() { + csc.mu.Lock() + defer csc.mu.Unlock() + + csc.cache = make(map[string]*cacheEntry) +} + +// Close shuts down the client-side cache and disables Redis client tracking +func (csc *ClientSideCache) Close() error { + // Stop invalidation processor + close(csc.stopCh) + csc.wg.Wait() + + // Close invalidation channel + close(csc.invalidations) + + // Unregister invalidation handler + csc.client.UnregisterPushNotificationHandler("invalidate") + + // Disable Redis client tracking + ctx := context.Background() + return csc.client.Do(ctx, "CLIENT", "TRACKING", "OFF").Err() +} + +// clientSideCacheInvalidationHandler handles Redis invalidate push notifications +type clientSideCacheInvalidationHandler struct { + cache *ClientSideCache +} + +// HandlePushNotification processes invalidate notifications from Redis +func (h *clientSideCacheInvalidationHandler) HandlePushNotification(ctx context.Context, notification []interface{}) bool { + if len(notification) < 2 { + return false + } + + // Extract invalidated keys from the notification + // Format: ["invalidate", [key1, key2, ...]] + var keys []string + if keyList, ok := notification[1].([]interface{}); ok { + for _, k := range keyList { + if key, ok := k.(string); ok { + keys = append(keys, key) + } + } + } + + if len(keys) == 0 { + return false + } + + // Send to invalidation processor (non-blocking) + select { + case h.cache.invalidations <- keys: + return true + default: + // Channel full - invalidations will be dropped, but cache entries will eventually expire + // This is acceptable for performance reasons + return false + } +} diff --git a/client_side_cache_test.go b/client_side_cache_test.go new file mode 100644 index 00000000..e9e20874 --- /dev/null +++ b/client_side_cache_test.go @@ -0,0 +1,483 @@ +package redis + +import ( + "context" + "testing" + "time" +) + +func TestClientSideCache(t *testing.T) { + t.Run("NewClientSideCache", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, // Required for push notifications + }) + defer client.Close() + + // Test with default options + cache, err := NewClientSideCache(client, nil) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + if cache == nil { + t.Error("NewClientSideCache should return a non-nil cache") + } + if cache.client != client { + t.Error("Cache should reference the provided client") + } + if cache.maxSize != 10000 { + t.Errorf("Expected default maxSize 10000, got %d", cache.maxSize) + } + if cache.defaultTTL != 5*time.Minute { + t.Errorf("Expected default TTL 5 minutes, got %v", cache.defaultTTL) + } + }) + + t.Run("NewClientSideCacheWithOptions", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + opts := &ClientSideCacheOptions{ + MaxSize: 5000, + DefaultTTL: 10 * time.Minute, + EnableTracking: true, + NoLoop: false, + TrackingPrefix: []string{"user:", "session:"}, + InvalidationBufferSize: 500, + } + + cache, err := NewClientSideCache(client, opts) + if err != nil { + t.Fatalf("Failed to create client-side cache with options: %v", err) + } + defer cache.Close() + + if cache.maxSize != 5000 { + t.Errorf("Expected maxSize 5000, got %d", cache.maxSize) + } + if cache.defaultTTL != 10*time.Minute { + t.Errorf("Expected TTL 10 minutes, got %v", cache.defaultTTL) + } + }) + + t.Run("CacheOperations", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Minute, + EnableTracking: false, // Disable tracking for unit tests + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + ctx := context.Background() + + // Test cache miss and Redis fallback + key := "test:cache:key1" + value := "test_value" + + // First get should be a cache miss + _, misses, _, _, _ := cache.GetStats() + initialMisses := misses + + cmd := cache.Get(ctx, key) + if cmd.Err() == nil { + // If Redis is available and key exists, verify it's cached + hits, misses, _, _, size := cache.GetStats() + if misses <= initialMisses { + t.Error("Expected cache miss on first get") + } + if size > 0 && cmd.Err() == nil { + // Second get should be a cache hit + cmd2 := cache.Get(ctx, key) + if cmd2.Err() != nil { + t.Errorf("Second get failed: %v", cmd2.Err()) + } + hits2, _, _, _, _ := cache.GetStats() + if hits2 <= hits { + t.Error("Expected cache hit on second get") + } + } + } + + // Test Set operation + setCmd := cache.Set(ctx, key, value, time.Hour) + if setCmd.Err() == nil { + // Verify value is cached locally + if cachedValue, found := cache.getFromCache(key); found { + if cachedValue != value { + t.Errorf("Expected cached value %s, got %v", value, cachedValue) + } + } + } + + // Test Del operation + delCmd := cache.Del(ctx, key) + if delCmd.Err() == nil { + // Verify value is removed from cache + if _, found := cache.getFromCache(key); found { + t.Error("Key should be removed from cache after Del") + } + } + }) + + t.Run("CacheEviction", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 2, // Small cache for testing eviction + DefaultTTL: 1 * time.Hour, + EnableTracking: false, + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + // Fill cache beyond capacity + cache.setInCache("key1", "value1", time.Hour) + cache.setInCache("key2", "value2", time.Hour) + cache.setInCache("key3", "value3", time.Hour) // Should trigger eviction + + // Check that cache size doesn't exceed maxSize + _, _, _, _, size := cache.GetStats() + if size > 2 { + t.Errorf("Cache size %d exceeds maxSize 2", size) + } + + // Check that eviction occurred + _, _, evictions, _, _ := cache.GetStats() + if evictions == 0 { + t.Error("Expected at least one eviction") + } + }) + + t.Run("CacheExpiration", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Hour, + EnableTracking: false, + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + // Set entry with short TTL + cache.setInCache("expiring_key", "value", 1*time.Millisecond) + + // Wait for expiration + time.Sleep(10 * time.Millisecond) + + // Try to get expired entry + if value, found := cache.getFromCache("expiring_key"); found { + t.Errorf("Expected expired entry to be removed, but found: %v", value) + } + }) + + t.Run("InvalidationHandler", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Hour, + EnableTracking: false, + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + // Create invalidation handler + handler := &clientSideCacheInvalidationHandler{cache: cache} + + // Add some entries to cache + cache.setInCache("key1", "value1", time.Hour) + cache.setInCache("key2", "value2", time.Hour) + + // Test invalidation notification + ctx := context.Background() + notification := []interface{}{"invalidate", []interface{}{"key1", "key2"}} + + handled := handler.HandlePushNotification(ctx, notification) + if !handled { + t.Error("Handler should return true for valid invalidation notification") + } + + // Give some time for async processing + time.Sleep(10 * time.Millisecond) + + // Verify keys are removed from cache + if _, found := cache.getFromCache("key1"); found { + t.Error("key1 should be invalidated") + } + if _, found := cache.getFromCache("key2"); found { + t.Error("key2 should be invalidated") + } + }) + + t.Run("CacheStats", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Hour, + EnableTracking: false, + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + // Initial stats + hits, misses, evictions, hitRatio, size := cache.GetStats() + if hits != 0 || misses != 0 || evictions != 0 || hitRatio != 0 || size != 0 { + t.Error("Initial stats should be zero") + } + + // Generate some cache activity + cache.getFromCache("nonexistent") // miss + cache.setInCache("key1", "value1", time.Hour) + cache.getFromCache("key1") // hit + + hits, misses, evictions, hitRatio, size = cache.GetStats() + if hits != 1 { + t.Errorf("Expected 1 hit, got %d", hits) + } + if misses != 1 { + t.Errorf("Expected 1 miss, got %d", misses) + } + if size != 1 { + t.Errorf("Expected cache size 1, got %d", size) + } + if hitRatio != 0.5 { + t.Errorf("Expected hit ratio 0.5, got %f", hitRatio) + } + }) + + t.Run("CacheClear", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + cache, err := NewClientSideCache(client, &ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Hour, + EnableTracking: false, + }) + if err != nil { + t.Fatalf("Failed to create client-side cache: %v", err) + } + defer cache.Close() + + // Add entries + cache.setInCache("key1", "value1", time.Hour) + cache.setInCache("key2", "value2", time.Hour) + + // Verify entries exist + _, _, _, _, size := cache.GetStats() + if size != 2 { + t.Errorf("Expected cache size 2, got %d", size) + } + + // Clear cache + cache.Clear() + + // Verify cache is empty + _, _, _, _, size = cache.GetStats() + if size != 0 { + t.Errorf("Expected cache size 0 after clear, got %d", size) + } + }) +} + +func TestClientSideCacheIntegration(t *testing.T) { + t.Run("EnableDisableClientSideCache", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + // Initially no cache + if cache := client.GetClientSideCache(); cache != nil { + t.Error("Client should not have cache initially") + } + + // Enable cache + err := client.EnableClientSideCache(nil) + if err != nil { + t.Fatalf("Failed to enable client-side cache: %v", err) + } + + // Verify cache is enabled + if cache := client.GetClientSideCache(); cache == nil { + t.Error("Client should have cache after enabling") + } + + // Try to enable again (should fail) + err = client.EnableClientSideCache(nil) + if err == nil { + t.Error("Enabling cache twice should return error") + } + + // Disable cache + err = client.DisableClientSideCache() + if err != nil { + t.Errorf("Failed to disable client-side cache: %v", err) + } + + // Verify cache is disabled + if cache := client.GetClientSideCache(); cache != nil { + t.Error("Client should not have cache after disabling") + } + + // Disable again (should not error) + err = client.DisableClientSideCache() + if err != nil { + t.Errorf("Disabling cache twice should not error: %v", err) + } + }) + + t.Run("CachedOperations", func(t *testing.T) { + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + ctx := context.Background() + + // Test operations without cache (should fallback to regular operations) + cmd1 := client.CachedGet(ctx, "test:key") + cmd2 := client.CachedSet(ctx, "test:key", "value", time.Hour) + cmd3 := client.CachedDel(ctx, "test:key") + + // These should work the same as regular operations + if cmd1 == nil || cmd2 == nil || cmd3 == nil { + t.Error("Cached operations should return valid commands even without cache") + } + + // Enable cache + err := client.EnableClientSideCache(&ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Hour, + EnableTracking: false, // Disable for unit tests + }) + if err != nil { + t.Fatalf("Failed to enable client-side cache: %v", err) + } + defer client.DisableClientSideCache() + + // Test operations with cache + cmd4 := client.CachedGet(ctx, "test:key2") + cmd5 := client.CachedSet(ctx, "test:key2", "value2", time.Hour) + cmd6 := client.CachedDel(ctx, "test:key2") + + if cmd4 == nil || cmd5 == nil || cmd6 == nil { + t.Error("Cached operations should return valid commands with cache enabled") + } + + // Verify cache is being used + cache := client.GetClientSideCache() + if cache == nil { + t.Error("Cache should be available") + } + }) + + t.Run("CacheWithRealRedis", func(t *testing.T) { + // This test requires a real Redis instance + client := NewClient(&Options{ + Addr: "localhost:6379", + Protocol: 3, + }) + defer client.Close() + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + t.Skip("Redis not available, skipping integration test") + } + + // Enable cache + err := client.EnableClientSideCache(&ClientSideCacheOptions{ + MaxSize: 100, + DefaultTTL: 1 * time.Minute, + EnableTracking: true, // Enable tracking for real Redis + NoLoop: true, + }) + if err != nil { + t.Fatalf("Failed to enable client-side cache: %v", err) + } + defer client.DisableClientSideCache() + + // Test key + key := "test:client_side_cache:integration" + value := "test_value_123" + + // Clean up + client.Del(ctx, key) + + // Set value + err = client.CachedSet(ctx, key, value, time.Hour).Err() + if err != nil { + t.Fatalf("Failed to set value: %v", err) + } + + // Get value (should be cached) + result, err := client.CachedGet(ctx, key).Result() + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + if result != value { + t.Errorf("Expected value %s, got %s", value, result) + } + + // Verify cache stats + cache := client.GetClientSideCache() + hits, misses, _, hitRatio, size := cache.GetStats() + if size == 0 { + t.Error("Cache should contain entries") + } + if hits+misses == 0 { + t.Error("Cache should have some activity") + } + + t.Logf("Cache stats: hits=%d, misses=%d, hitRatio=%.2f, size=%d", hits, misses, hitRatio, size) + + // Clean up + client.Del(ctx, key) + }) +} diff --git a/examples/client-side-cache/go.mod b/examples/client-side-cache/go.mod new file mode 100644 index 00000000..e06b9f45 --- /dev/null +++ b/examples/client-side-cache/go.mod @@ -0,0 +1,7 @@ +module client-side-cache-example + +go 1.21 + +require github.com/redis/go-redis/v9 v9.11.0 + +replace github.com/redis/go-redis/v9 => ../.. diff --git a/examples/client-side-cache/main.go b/examples/client-side-cache/main.go new file mode 100644 index 00000000..3750aae8 --- /dev/null +++ b/examples/client-side-cache/main.go @@ -0,0 +1,175 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/redis/go-redis/v9" +) + +func main() { + fmt.Println("=== Redis Client-Side Caching Example ===") + + // Create Redis client with RESP3 protocol (required for push notifications) + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // RESP3 required for client tracking and push notifications + }) + defer client.Close() + + ctx := context.Background() + + // Test Redis connection + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + fmt.Println("โœ… Connected to Redis") + + // Enable client-side caching + err := client.EnableClientSideCache(&redis.ClientSideCacheOptions{ + MaxSize: 1000, // Cache up to 1000 entries + DefaultTTL: 5 * time.Minute, // Default TTL for cached entries + EnableTracking: true, // Enable Redis CLIENT TRACKING + NoLoop: true, // Don't track keys modified by this client + TrackingPrefix: []string{"user:", "session:"}, // Only track specific prefixes + }) + if err != nil { + log.Fatalf("Failed to enable client-side cache: %v", err) + } + defer client.DisableClientSideCache() + fmt.Println("โœ… Client-side cache enabled") + + // Example 1: Basic caching operations + fmt.Println("\n๐Ÿ”ง Example 1: Basic Caching Operations") + + key := "user:123" + value := "John Doe" + + // Set a value (stored in Redis and cached locally) + err = client.CachedSet(ctx, key, value, time.Hour).Err() + if err != nil { + log.Fatalf("Failed to set value: %v", err) + } + fmt.Printf("โœ… Set %s = %s\n", key, value) + + // Get the value (should be served from cache) + result, err := client.CachedGet(ctx, key).Result() + if err != nil { + log.Fatalf("Failed to get value: %v", err) + } + fmt.Printf("โœ… Got %s = %s (from cache)\n", key, result) + + // Example 2: Cache statistics + fmt.Println("\n๐Ÿ”ง Example 2: Cache Statistics") + + cache := client.GetClientSideCache() + hits, misses, evictions, hitRatio, size := cache.GetStats() + fmt.Printf("๐Ÿ“Š Cache Stats:\n") + fmt.Printf(" Hits: %d\n", hits) + fmt.Printf(" Misses: %d\n", misses) + fmt.Printf(" Evictions: %d\n", evictions) + fmt.Printf(" Hit Ratio: %.2f%%\n", hitRatio*100) + fmt.Printf(" Size: %d entries\n", size) + + // Example 3: Multiple operations to show caching in action + fmt.Println("\n๐Ÿ”ง Example 3: Multiple Operations") + + keys := []string{"user:100", "user:101", "user:102"} + values := []string{"Alice", "Bob", "Charlie"} + + // Set multiple values + for i, k := range keys { + err = client.CachedSet(ctx, k, values[i], time.Hour).Err() + if err != nil { + log.Printf("Failed to set %s: %v", k, err) + } + } + fmt.Println("โœ… Set multiple user values") + + // Get values multiple times to show cache hits + for round := 1; round <= 3; round++ { + fmt.Printf("\n๐Ÿ“‹ Round %d - Getting values:\n", round) + for _, k := range keys { + start := time.Now() + result, err := client.CachedGet(ctx, k).Result() + duration := time.Since(start) + if err != nil { + log.Printf("Failed to get %s: %v", k, err) + continue + } + fmt.Printf(" %s = %s (took %v)\n", k, result, duration) + } + } + + // Show updated statistics + hits, misses, evictions, hitRatio, size = cache.GetStats() + fmt.Printf("\n๐Ÿ“Š Updated Cache Stats:\n") + fmt.Printf(" Hits: %d\n", hits) + fmt.Printf(" Misses: %d\n", misses) + fmt.Printf(" Evictions: %d\n", evictions) + fmt.Printf(" Hit Ratio: %.2f%%\n", hitRatio*100) + fmt.Printf(" Size: %d entries\n", size) + + // Example 4: Cache invalidation + fmt.Println("\n๐Ÿ”ง Example 4: Cache Invalidation") + + // Modify a value from another client to trigger invalidation + // (In a real scenario, this would be another application instance) + fmt.Println("๐Ÿ“‹ Simulating external modification...") + + // Use regular Set to modify the value (this will trigger invalidation) + err = client.Set(ctx, "user:100", "Alice Updated", time.Hour).Err() + if err != nil { + log.Printf("Failed to update value: %v", err) + } else { + fmt.Println("โœ… Updated user:100 externally") + + // Give some time for invalidation to process + time.Sleep(100 * time.Millisecond) + + // Get the value again (should fetch from Redis due to invalidation) + result, err := client.CachedGet(ctx, "user:100").Result() + if err != nil { + log.Printf("Failed to get updated value: %v", err) + } else { + fmt.Printf("โœ… Got updated value: %s\n", result) + } + } + + // Example 5: Cache management + fmt.Println("\n๐Ÿ”ง Example 5: Cache Management") + + // Clear the cache + cache.Clear() + fmt.Println("โœ… Cache cleared") + + // Show final statistics + hits, misses, evictions, hitRatio, size = cache.GetStats() + fmt.Printf("๐Ÿ“Š Final Cache Stats:\n") + fmt.Printf(" Hits: %d\n", hits) + fmt.Printf(" Misses: %d\n", misses) + fmt.Printf(" Evictions: %d\n", evictions) + fmt.Printf(" Hit Ratio: %.2f%%\n", hitRatio*100) + fmt.Printf(" Size: %d entries\n", size) + + // Clean up test keys + fmt.Println("\n๐Ÿงน Cleaning up...") + allKeys := append(keys, key) + deleted, err := client.Del(ctx, allKeys...).Result() + if err != nil { + log.Printf("Failed to clean up keys: %v", err) + } else { + fmt.Printf("โœ… Deleted %d keys\n", deleted) + } + + fmt.Println("\n๐ŸŽ‰ Client-Side Caching Example Complete!") + fmt.Println("\n๐Ÿ“‹ Key Benefits Demonstrated:") + fmt.Println(" โœ… Automatic local caching with Redis fallback") + fmt.Println(" โœ… Real-time cache invalidation via Redis push notifications") + fmt.Println(" โœ… Significant performance improvements for repeated reads") + fmt.Println(" โœ… Transparent integration with existing Redis operations") + fmt.Println(" โœ… Comprehensive statistics and monitoring") + fmt.Println(" โœ… Configurable cache size, TTL, and tracking options") +} diff --git a/internal/pushnotif/processor.go b/internal/pushnotif/processor.go index 4476ecb8..198226b1 100644 --- a/internal/pushnotif/processor.go +++ b/internal/pushnotif/processor.go @@ -127,8 +127,8 @@ func shouldSkipNotification(notificationType string) bool { "xread-from", // Stream reading notifications "xreadgroup-from", // Stream consumer group notifications - // Client tracking notifications - handled by client tracking system - "invalidate", // Client-side caching invalidation + // Client tracking notifications - handled by client-side cache system + // Note: "invalidate" is now handled by client-side cache, not filtered // Keyspace notifications - handled by keyspace notification subscribers // Note: Keyspace notifications typically have prefixes like "__keyspace@0__:" or "__keyevent@0__:" diff --git a/internal/pushnotif/pushnotif_test.go b/internal/pushnotif/pushnotif_test.go index 3fa84e88..cd9145c0 100644 --- a/internal/pushnotif/pushnotif_test.go +++ b/internal/pushnotif/pushnotif_test.go @@ -686,6 +686,7 @@ func TestShouldSkipNotification(t *testing.T) { "MIGRATED", // Cluster slot migration "FAILING_OVER", // Cluster failover "FAILED_OVER", // Cluster failover + "invalidate", // Client-side caching invalidation (now handled by cache) "unknown", // Unknown message type "", // Empty string "MESSAGE", // Case sensitive - should not match diff --git a/redis.go b/redis.go index b9e54fb8..bf1361f5 100644 --- a/redis.go +++ b/redis.go @@ -210,6 +210,9 @@ type baseClient struct { // Push notification processing pushProcessor PushNotificationProcessorInterface + + // Client-side cache for automatic caching with Redis invalidation + clientSideCache *ClientSideCache } func (c *baseClient) clone() *baseClient { @@ -835,11 +838,83 @@ func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, ha return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) } +// UnregisterPushNotificationHandler unregisters a handler for a specific push notification name. +// Returns an error if no handler is registered for this push notification name or if the handler is protected. +func (c *Client) UnregisterPushNotificationHandler(pushNotificationName string) error { + // Check if we have a processor that supports unregistration + if processor, ok := c.pushProcessor.(interface { + UnregisterHandler(pushNotificationName string) error + }); ok { + return processor.UnregisterHandler(pushNotificationName) + } + return fmt.Errorf("push notification processor does not support unregistration") +} + // GetPushNotificationProcessor returns the push notification processor. func (c *Client) GetPushNotificationProcessor() PushNotificationProcessorInterface { return c.pushProcessor } +// EnableClientSideCache enables client-side caching with Redis invalidation support. +// This creates a local cache that automatically invalidates entries when Redis sends +// invalidation notifications through CLIENT TRACKING. +func (c *Client) EnableClientSideCache(opts *ClientSideCacheOptions) error { + if c.clientSideCache != nil { + return fmt.Errorf("client-side cache is already enabled") + } + + cache, err := NewClientSideCache(c, opts) + if err != nil { + return err + } + + c.clientSideCache = cache + return nil +} + +// DisableClientSideCache disables client-side caching and cleans up resources. +func (c *Client) DisableClientSideCache() error { + if c.clientSideCache == nil { + return nil // Already disabled + } + + err := c.clientSideCache.Close() + c.clientSideCache = nil + return err +} + +// GetClientSideCache returns the client-side cache if enabled, nil otherwise. +func (c *Client) GetClientSideCache() *ClientSideCache { + return c.clientSideCache +} + +// CachedGet retrieves a value using client-side caching if enabled, otherwise falls back to regular Get. +// This is a convenience method that automatically uses the cache when available. +func (c *Client) CachedGet(ctx context.Context, key string) *StringCmd { + if c.clientSideCache != nil { + return c.clientSideCache.Get(ctx, key) + } + return c.Get(ctx, key) +} + +// CachedSet stores a value using client-side caching if enabled, otherwise falls back to regular Set. +// This is a convenience method that automatically updates the cache when available. +func (c *Client) CachedSet(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd { + if c.clientSideCache != nil { + return c.clientSideCache.Set(ctx, key, value, expiration) + } + return c.Set(ctx, key, value, expiration) +} + +// CachedDel deletes keys using client-side caching if enabled, otherwise falls back to regular Del. +// This is a convenience method that automatically updates the cache when available. +func (c *Client) CachedDel(ctx context.Context, keys ...string) *IntCmd { + if c.clientSideCache != nil { + return c.clientSideCache.Del(ctx, keys...) + } + return c.Del(ctx, keys...) +} + // GetPushNotificationHandler returns the handler for a specific push notification name. // Returns nil if no handler is registered for the given name. func (c *Client) GetPushNotificationHandler(pushNotificationName string) PushNotificationHandler {