1
0
mirror of https://github.com/redis/go-redis.git synced 2025-12-02 06:22:31 +03:00
Files
go-redis/internal/semaphore.go
Nedyalko Dyakov dc319c0f7e should properly notify the waiters
- this way a waiter that times out at the same time
a releaser is releasing won't throw away the token. The releaser
will fail to notify it and will pick another waiter.

this hybrid approach should be faster than channels and maintains FIFO
2025-11-10 13:47:41 +02:00

333 lines
8.2 KiB
Go

package internal
import (
"context"
"sync"
"sync/atomic"
"time"
)
// semTimers recycles *time.Timer values used by the blocking path of Acquire.
// A timer created by New is stopped before it is handed out, so the caller
// has to Reset it before waiting on its channel.
var semTimers = sync.Pool{
	New: func() any {
		stopped := time.NewTimer(time.Hour)
		stopped.Stop()
		return stopped
	},
}
// waiter is the queue entry for one goroutine blocked on the semaphore.
type waiter struct {
	// ready is closed to wake the goroutine that owns this waiter.
	ready chan struct{}
	// next points at the waiter queued directly behind this one.
	next *waiter

	// cancelled is set by the waiting goroutine when it gives up
	// (timeout or context cancellation).
	cancelled atomic.Bool
	// notified is flipped from false to true by compare-and-swap, either by
	// Release when it hands a token over or by the waiter when it gives up;
	// whoever wins the swap decides whether a token was transferred.
	notified atomic.Bool
}
// FastSemaphore is a counting semaphore made of two parts: a buffered channel
// that stores the free tokens and a mutex-guarded FIFO list of blocked
// waiters.
//
// Uncontended calls only touch the channel:
//   - TryAcquire, and the first step of Acquire/AcquireBlocking, perform a
//     non-blocking receive;
//   - Release performs a send when nobody is queued.
//
// When no token is free the caller appends a waiter to the list, and Release
// passes its token straight to the waiter at the head of the list, so blocked
// callers are woken in arrival order.
type FastSemaphore struct {
	// tokens holds one element per free token.
	tokens chan struct{}
	// max is the total number of tokens (the capacity).
	max int32

	// lock guards head and tail.
	lock sync.Mutex
	// head is the oldest queued waiter, tail the newest.
	head *waiter
	tail *waiter
}
// NewFastSemaphore returns a semaphore holding capacity tokens, all of them
// free.
func NewFastSemaphore(capacity int32) *FastSemaphore {
	sem := &FastSemaphore{
		tokens: make(chan struct{}, capacity),
		max:    capacity,
	}
	// Every slot of the channel starts out as an available token.
	for free := int32(0); free < capacity; free++ {
		sem.tokens <- struct{}{}
	}
	return sem
}
// TryAcquire takes a token if one is free and never blocks.
// It reports true when a token was taken and false when none was available.
//
// The whole operation is one non-blocking receive on the token channel.
func (s *FastSemaphore) TryAcquire() bool {
	acquired := false
	select {
	case <-s.tokens:
		acquired = true
	default:
		// No free token right now.
	}
	return acquired
}
// enqueue appends w behind the current tail of the waiter queue.
// The caller must hold s.lock.
func (s *FastSemaphore) enqueue(w *waiter) {
	if last := s.tail; last != nil {
		last.next = w
	} else {
		// Empty queue: w is also the new head.
		s.head = w
	}
	s.tail = w
}
// dequeue detaches the waiter at the head of the queue and returns it, or
// returns nil when the queue is empty. The caller must hold s.lock.
func (s *FastSemaphore) dequeue() *waiter {
	first := s.head
	if first == nil {
		return nil
	}
	if s.head = first.next; s.head == nil {
		// That was the last waiter, so the queue is empty again.
		s.tail = nil
	}
	first.next = nil
	return first
}
// notifyOne removes the oldest queued waiter, if there is one, and closes its
// ready channel. It does nothing when the queue is empty.
//
// Unlike Release it neither reads nor sets the waiter's cancelled/notified
// flags.
// NOTE(review): no caller is visible in this file, and a goroutine woken
// through ready treats the wake-up as a token hand-off (see Acquire) -
// confirm that any caller really owns a token to give away.
func (s *FastSemaphore) notifyOne() {
	s.lock.Lock()
	first := s.dequeue()
	s.lock.Unlock()

	if first == nil {
		return
	}
	close(first.ready)
}
// Acquire obtains a token, waiting if none is free.
//
// It returns nil once a token has been acquired, ctx.Err() if ctx is done
// before that happens, and timeoutErr if timeout elapses first. A context that
// is already done on entry fails immediately, even when a token is free.
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
	// An already-finished context wins over an available token.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	// Fast path: take a free token without blocking.
	select {
	case <-s.tokens:
		return nil
	default:
	}

	// Slow path: queue up so that Release can hand a token over directly.
	w := &waiter{ready: make(chan struct{})}
	s.lock.Lock()
	s.enqueue(w)
	s.lock.Unlock()

	// Reuse a pooled timer instead of allocating one per blocked call.
	timer := semTimers.Get().(*time.Timer)
	defer semTimers.Put(timer)
	timer.Reset(timeout)

	select {
	case <-w.ready:
		// Release handed its token to us.
		if !timer.Stop() {
			<-timer.C
		}
		return nil

	case <-s.tokens:
		// A Release that found the queue empty may put its token into the
		// channel after our fast-path miss (before or after we enqueued).
		// Without this case the token would sit in the channel while we keep
		// waiting for a hand-off that never comes.
		if !timer.Stop() {
			<-timer.C
		}
		if s.abandonWait(w) {
			// Release handed us a second token in the meantime; we only
			// need one, so pass the extra one on.
			s.releaseToPool()
		}
		return nil

	case <-ctx.Done():
		if !timer.Stop() {
			<-timer.C
		}
		if s.abandonWait(w) {
			// We were handed a token but are giving up: pass it on.
			s.releaseToPool()
		}
		return ctx.Err()

	case <-timer.C:
		if s.abandonWait(w) {
			// We were handed a token but are giving up: pass it on.
			s.releaseToPool()
		}
		return timeoutErr
	}
}

// abandonWait withdraws w from the hand-off protocol once its goroutine no
// longer waits on w.ready. It reports whether Release had already claimed w,
// i.e. whether a token was handed to w that the caller now owns.
//
// The notified flag arbitrates between the waiter and Release: whoever swaps
// it from false to true first wins.
func (s *FastSemaphore) abandonWait(w *waiter) (handedToken bool) {
	w.cancelled.Store(true)
	if !w.notified.CompareAndSwap(false, true) {
		// Release won the race and is transferring a token to us; wait
		// until it has finished with w.
		<-w.ready
		return true
	}

	// We won the race, so nobody will transfer a token to w.
	s.lock.Lock()
	removed := s.removeWaiter(w)
	s.lock.Unlock()
	if !removed {
		// Release already dequeued w; it will fail the claim, close ready
		// and move on to the next waiter.
		<-w.ready
	}
	return false
}
// removeWaiter unlinks target from the waiter queue.
// The caller must hold s.lock.
//
// It returns true if target was found and removed, and false if target is not
// in the queue (for example because it was already dequeued). The removed
// waiter's next pointer is cleared, as dequeue does, so a detached waiter
// never keeps a link into the live queue.
func (s *FastSemaphore) removeWaiter(target *waiter) bool {
	var prev *waiter
	for cur := s.head; cur != nil; prev, cur = cur, cur.next {
		if cur != target {
			continue
		}
		if prev == nil {
			// target is the head.
			s.head = cur.next
		} else {
			prev.next = cur.next
		}
		if cur.next == nil {
			// target was the tail; prev is nil when the queue is now empty.
			s.tail = prev
		}
		cur.next = nil
		return true
	}
	return false
}
// AcquireBlocking obtains a token and waits without any time limit until one
// is available. It returns immediately when a token is free (fast path).
// Use it where neither a timeout nor context cancellation is needed.
func (s *FastSemaphore) AcquireBlocking() {
	// Fast path: take a free token without blocking.
	select {
	case <-s.tokens:
		return
	default:
	}

	// Slow path: queue up so that Release can hand a token over directly.
	w := &waiter{ready: make(chan struct{})}
	s.lock.Lock()
	s.enqueue(w)
	s.lock.Unlock()

	// Wait for a hand-off, but also watch the channel: a Release that found
	// the queue empty may put its token there after our fast-path miss.
	// Waiting on w.ready alone would then block although a token is free.
	select {
	case <-w.ready:
		// Release handed its token to us.
		return
	case <-s.tokens:
	}

	// We took a token from the channel, so leave the hand-off protocol.
	// Whoever swaps notified from false to true first wins.
	w.cancelled.Store(true)
	if !w.notified.CompareAndSwap(false, true) {
		// Release claimed us first and is transferring a second token; wait
		// for it and pass the extra token on.
		<-w.ready
		s.releaseToPool()
		return
	}
	s.lock.Lock()
	removed := s.removeWaiter(w)
	s.lock.Unlock()
	if !removed {
		// Release already dequeued us; it will fail the claim, close ready
		// and move on to the next waiter.
		<-w.ready
	}
}
// releaseToPool gives away a token that was handed to a waiter which has
// meanwhile stopped waiting (cancelled or timed out). The token goes to the
// next live waiter if there is one, otherwise back into the channel.
//
// It delegates to Release so that the hand-off goes through the same
// cancelled/notified claim as every other release. Closing the head waiter's
// ready channel without that claim would lose the token whenever that waiter
// had already claimed itself while giving up: it would return an error and
// the token would never be returned.
func (s *FastSemaphore) releaseToPool() {
	s.Release()
}
// Release gives a token back to the semaphore.
//
// Queued waiters are served first, oldest first: the token goes to the first
// waiter that has not given up. Only when no such waiter exists is the token
// put back into the channel.
func (s *FastSemaphore) Release() {
	for {
		s.lock.Lock()
		next := s.dequeue()
		s.lock.Unlock()

		if next == nil {
			// Nobody is queued: the token becomes free again.
			s.tokens <- struct{}{}
			return
		}

		// A waiter that already marked itself cancelled is skipped without
		// touching notified. Otherwise the compare-and-swap decides the race
		// against a waiter that is giving up right now: only one side can
		// flip notified.
		handedOver := !next.cancelled.Load() && next.notified.CompareAndSwap(false, true)

		// ready is closed in either case; a waiter that gave up may still be
		// blocked on it.
		close(next.ready)
		if handedOver {
			return
		}
		// We still own the token, so try the next waiter.
	}
}
// Len reports how many tokens are currently acquired, computed as the
// capacity minus the number of free tokens sitting in the channel.
// Used by tests to check semaphore state.
func (s *FastSemaphore) Len() int32 {
	free := int32(len(s.tokens))
	return s.max - free
}