From 162a15432b02a9b57d74003e5cf574267f66028a Mon Sep 17 00:00:00 2001
From: Bhargav Dodla <13788369+EXPEbdodla@users.noreply.github.com>
Date: Wed, 5 Mar 2025 12:08:27 -0800
Subject: [PATCH] fix: Fixed issue with context cancelled error leading to
 connection spikes on Primary instances (#3190)

* fix: Fixed issue with context cancelled error leading to connection spikes on Master

* fix: Added tests

* fix: Updated tests

---------

Co-authored-by: Bhargav Dodla <bdodla@expediagroup.com>
Co-authored-by: Nedyalko Dyakov <nedyalko.dyakov@gmail.com>
---
 error.go           |  9 +++++++++
 osscluster.go      |  4 +++-
 osscluster_test.go | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/error.go b/error.go
index 9b348193..a7bf159c 100644
--- a/error.go
+++ b/error.go
@@ -38,6 +38,15 @@ type Error interface {
 
 var _ Error = proto.RedisError("")
 
+func isContextError(err error) bool {
+	switch err {
+	case context.Canceled, context.DeadlineExceeded:
+		return true
+	default:
+		return false
+	}
+}
+
 func shouldRetry(err error, retryTimeout bool) bool {
 	switch err {
 	case io.EOF, io.ErrUnexpectedEOF:
diff --git a/osscluster.go b/osscluster.go
index 517fbd45..1e9ee7de 100644
--- a/osscluster.go
+++ b/osscluster.go
@@ -1350,7 +1350,9 @@ func (c *ClusterClient) processPipelineNode(
 	_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
 		cn, err := node.Client.getConn(ctx)
 		if err != nil {
-			node.MarkAsFailing()
+			if !isContextError(err) {
+				node.MarkAsFailing()
+			}
 			_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
 			setCmdsErr(cmds, err)
 			return err
diff --git a/osscluster_test.go b/osscluster_test.go
index aeb34c6b..ccf6daad 100644
--- a/osscluster_test.go
+++ b/osscluster_test.go
@@ -539,6 +539,39 @@ var _ = Describe("ClusterClient", func() {
 				AfterEach(func() {})
 
 				assertPipeline()
+
+				It("doesn't fail node with context.Canceled error", func() {
+					ctx, cancel := context.WithCancel(context.Background())
+					cancel()
+					pipe.Set(ctx, "A", "A_value", 0)
+					_, err := pipe.Exec(ctx)
+
+					Expect(err).To(HaveOccurred())
+					Expect(errors.Is(err, context.Canceled)).To(BeTrue())
+
+					clientNodes, _ := client.Nodes(ctx, "A")
+
+					for _, node := range clientNodes {
+						Expect(node.Failing()).To(BeFalse())
+					}
+				})
+
+				It("doesn't fail node with context.DeadlineExceeded error", func() {
+					ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
+					defer cancel()
+
+					pipe.Set(ctx, "A", "A_value", 0)
+					_, err := pipe.Exec(ctx)
+
+					Expect(err).To(HaveOccurred())
+					Expect(errors.Is(err, context.DeadlineExceeded)).To(BeTrue())
+
+					clientNodes, _ := client.Nodes(ctx, "A")
+
+					for _, node := range clientNodes {
+						Expect(node.Failing()).To(BeFalse())
+					}
+				})
 			})
 
 			Describe("with TxPipeline", func() {