mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add thread-safe wrappers for components in pipeline (examples/server-async/utils/requestscopedpipeline.py) (#12515)
* Basic implementation of request scheduling
* Basic editing in SD and Flux Pipelines
* Small Fix
* Fix
* Update for more pipelines
* Add examples/server-async
* Add examples/server-async
* Updated RequestScopedPipeline to handle a single tokenizer lock to avoid race conditions
* Fix
* Fix _TokenizerLockWrapper
* Fix _TokenizerLockWrapper
* Delete _TokenizerLockWrapper
* Fix tokenizer
* Update examples/server-async
* Fix server-async
* Optimizations in examples/server-async
* Keep the implementation simple in examples/server-async
* Update examples/server-async/README.md
* Update examples/server-async/README.md for changes to tokenizer locks and backward-compatible retrieve_timesteps
* Undo the changes to the diffusers core and move all logic to examples/server-async
* Update examples/server-async/utils/*
* Fix BaseAsyncScheduler
* Rollback in the diffusers core
* Update examples/server-async/README.md
* Complete rollback of diffusers core files
* Simple implementation of an asynchronous server compatible with SD3-3.5 and Flux Pipelines
* Update examples/server-async/README.md
* Fix import errors in 'examples/server-async/serverasync.py'
* Discard the Flux Pipeline
* Update examples/server-async/README.md
* Apply style fixes
* Add thread-safe wrappers for components in pipeline

  Refactor requestscopedpipeline.py to add thread-safe wrappers for the tokenizer, VAE, and image processor. Introduce locking mechanisms to ensure thread safety during concurrent access.

* Add wrappers.py
* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
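Before the diff, a minimal usage sketch of the pattern this commit completes. This is illustrative only and not part of the commit: the checkpoint id, thread count, and import path assume the examples/server-async layout.

import threading

import torch
from diffusers import StableDiffusion3Pipeline

from utils.requestscopedpipeline import RequestScopedPipeline

# One base pipeline is loaded once; RequestScopedPipeline hands each request
# its own scheduler clone plus thread-safe tokenizer/VAE/image-processor wrappers.
base = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16
).to("cuda")
pipe = RequestScopedPipeline(base)

def handle(prompt: str):
    # Safe to call concurrently: mutable state is cloned per request.
    return pipe.generate(prompt, num_inference_steps=30)

threads = [threading.Thread(target=handle, args=(f"a photo, seed {i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

Each generate() call clones the scheduler and shallow-copies the pipeline, so only the heavyweight weights are shared across threads.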
examples/server-async/utils/requestscopedpipeline.py
@@ -7,16 +7,12 @@ import torch
 
 from diffusers.utils import logging
 
 from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
 
 
 logger = logging.get_logger(__name__)
 
 
-def safe_tokenize(tokenizer, *args, lock, **kwargs):
-    with lock:
-        return tokenizer(*args, **kwargs)
-
-
 class RequestScopedPipeline:
     DEFAULT_MUTABLE_ATTRS = [
         "_all_hooks",
@@ -38,23 +34,40 @@ class RequestScopedPipeline:
         wrap_scheduler: bool = True,
     ):
         self._base = pipeline
 
         self.unet = getattr(pipeline, "unet", None)
         self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)
 
+        self.transformer = getattr(pipeline, "transformer", None)
+
         if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
             if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                 pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
 
         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
 
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
 
+        self._vae_lock = threading.Lock()
+        self._image_lock = threading.Lock()
+
         self._auto_detect_mutables = bool(auto_detect_mutables)
         self._tensor_numel_threshold = int(tensor_numel_threshold)
 
         self._auto_detected_attrs: List[str] = []
 
+    def _detect_kernel_pipeline(self, pipeline) -> bool:
+        kernel_indicators = [
+            "text_encoding_cache",
+            "memory_manager",
+            "enable_optimizations",
+            "_create_request_context",
+            "get_optimization_stats",
+        ]
+
+        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
+
     def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
         base_sched = getattr(self._base, "scheduler", None)
         if base_sched is None:
@@ -70,11 +83,21 @@ class RequestScopedPipeline:
                 num_inference_steps=num_inference_steps, device=device, **clone_kwargs
             )
         except Exception as e:
-            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+            logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
             try:
-                return copy.deepcopy(wrapped_scheduler)
-            except Exception as e:
-                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+                if hasattr(wrapped_scheduler, "scheduler"):
+                    try:
+                        copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
+                        return BaseAsyncScheduler(copied_scheduler)
+                    except Exception:
+                        return wrapped_scheduler
+                else:
+                    copied_scheduler = copy.copy(wrapped_scheduler)
+                    return BaseAsyncScheduler(copied_scheduler)
+            except Exception as e2:
+                logger.warning(
+                    f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
+                )
                 return wrapped_scheduler
 
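To summarize the fallback chain above in one place: prefer the scheduler's per-request clone, then a shallow copy re-wrapped in BaseAsyncScheduler, and only then the shared original. A sketch under the assumption (suggested by the call site above) that the wrapped scheduler exposes clone_for_request; the function name here is hypothetical.

import copy

from utils.scheduler import BaseAsyncScheduler

def local_scheduler_for_request(wrapped, num_inference_steps, device=None):
    try:
        # 1) Preferred: a purpose-built per-request clone.
        return wrapped.clone_for_request(num_inference_steps=num_inference_steps, device=device)
    except Exception:
        pass
    try:
        # 2) Fallback: shallow-copy the underlying scheduler and re-wrap it.
        inner = getattr(wrapped, "scheduler", wrapped)
        return BaseAsyncScheduler(copy.copy(inner))
    except Exception:
        # 3) Last resort: share the original instance (thread-unsafe but functional).
        return wrapped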
     def _autodetect_mutables(self, max_attrs: int = 40):
@@ -86,6 +109,7 @@ class RequestScopedPipeline:
 
         candidates: List[str] = []
         seen = set()
+
         for name in dir(self._base):
             if name.startswith("__"):
                 continue
@@ -93,6 +117,7 @@ class RequestScopedPipeline:
                 continue
             if name in ("to", "save_pretrained", "from_pretrained"):
                 continue
+
             try:
                 val = getattr(self._base, name)
             except Exception:
@@ -100,11 +125,9 @@ class RequestScopedPipeline:
 
         import types
-
         # skip callables and modules
         if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
             continue
-
         # containers -> candidate
         if isinstance(val, (dict, list, set, tuple, bytearray)):
             candidates.append(name)
             seen.add(name)
@@ -205,6 +228,9 @@ class RequestScopedPipeline:
 
         return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
 
+    def _should_wrap_tokenizers(self) -> bool:
+        return True
+
     def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
         local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
 
@@ -214,6 +240,25 @@ class RequestScopedPipeline:
             logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
             local_pipe = copy.deepcopy(self._base)
 
+        try:
+            if (
+                hasattr(local_pipe, "vae")
+                and local_pipe.vae is not None
+                and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
+            ):
+                local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
+
+            if (
+                hasattr(local_pipe, "image_processor")
+                and local_pipe.image_processor is not None
+                and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
+            ):
+                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
+                    local_pipe.image_processor, self._image_lock
+                )
+        except Exception as e:
+            logger.debug(f"Could not wrap vae/image_processor: {e}")
+
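Note that the block above swaps vae and image_processor in only when they exist and are not already wrapped, so repeated generate() calls never double-wrap. A stand-in demonstration of the same hand-off (DummyVAE and the import path are illustrative):

import threading

from utils.wrappers import ThreadSafeVAEWrapper

class DummyVAE:
    def decode(self, latents):
        return f"decoded({latents})"

lock = threading.Lock()
vae = ThreadSafeVAEWrapper(DummyVAE(), lock)
# decode() resolves through __getattr__, so the lock is held for the whole call.
print(vae.decode("latents"))  # -> decoded(latents)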
         if local_scheduler is not None:
             try:
                 timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -231,47 +276,42 @@ class RequestScopedPipeline:
 
         self._clone_mutable_attrs(self._base, local_pipe)
 
-        # 4) wrap tokenizers on the local pipe with the lock wrapper
-        tokenizer_wrappers = {}  # name -> original_tokenizer
-        try:
-            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
-            for name in dir(local_pipe):
-                if "tokenizer" in name and not name.startswith("_"):
-                    tok = getattr(local_pipe, name, None)
-                    if tok is not None and self._is_tokenizer_component(tok):
-                        tokenizer_wrappers[name] = tok
-                        setattr(
-                            local_pipe,
-                            name,
-                            lambda *args, tok=tok, **kwargs: safe_tokenize(
-                                tok, *args, lock=self._tokenizer_lock, **kwargs
-                            ),
-                        )
+        original_tokenizers = {}
 
-            # b) wrap tokenizers in components dict
-            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
-                for key, val in local_pipe.components.items():
-                    if val is None:
-                        continue
+        if self._should_wrap_tokenizers():
+            try:
+                for name in dir(local_pipe):
+                    if "tokenizer" in name and not name.startswith("_"):
+                        tok = getattr(local_pipe, name, None)
+                        if tok is not None and self._is_tokenizer_component(tok):
+                            if not isinstance(tok, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[name] = tok
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
+                                setattr(local_pipe, name, wrapped_tokenizer)
 
-                    if self._is_tokenizer_component(val):
-                        tokenizer_wrappers[f"components[{key}]"] = val
-                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
-                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
-                        )
+                if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                    for key, val in local_pipe.components.items():
+                        if val is None:
+                            continue
 
-        except Exception as e:
-            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+                        if self._is_tokenizer_component(val):
+                            if not isinstance(val, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[f"components[{key}]"] = val
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
+                                local_pipe.components[key] = wrapped_tokenizer
+
+            except Exception as e:
+                logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
 
         result = None
         cm = getattr(local_pipe, "model_cpu_offload_context", None)
 
         try:
             if callable(cm):
                 try:
                     with cm():
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                 except TypeError:
                     # cm might be a context manager instance rather than callable
                     try:
                         with cm:
                             result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
@@ -279,18 +319,18 @@ class RequestScopedPipeline:
                     logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                     result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
             else:
                 # no offload context available — call directly
                 result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
 
             return result
 
         finally:
             try:
-                for name, tok in tokenizer_wrappers.items():
+                for name, tok in original_tokenizers.items():
                     if name.startswith("components["):
                         key = name[len("components[") : -1]
-                        local_pipe.components[key] = tok
+                        if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                            local_pipe.components[key] = tok
                     else:
                         setattr(local_pipe, name, tok)
             except Exception as e:
-                logger.debug(f"Error restoring wrapped tokenizers: {e}")
+                logger.debug(f"Error restoring original tokenizers: {e}")
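The finally block matters because local_pipe is a shallow copy: its tokenizer attributes can still be shared with the base pipeline, so the originals recorded in original_tokenizers are always put back, even on error. The same pattern in isolation, with a hypothetical stand-in tokenizer:

import threading

from utils.wrappers import ThreadSafeTokenizerWrapper

class DummyTokenizer:
    def __call__(self, text):
        return {"input_ids": [len(text)]}

tokenizer = DummyTokenizer()
originals = {"tokenizer": tokenizer}  # remember the unwrapped object
tokenizer = ThreadSafeTokenizerWrapper(tokenizer, threading.Lock())
try:
    tokenizer("a prompt")  # serialized through the shared lock
finally:
    tokenizer = originals["tokenizer"]  # restore, even if inference raised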
examples/server-async/utils/wrappers.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+class ThreadSafeTokenizerWrapper:
+    def __init__(self, tokenizer, lock):
+        self._tokenizer = tokenizer
+        self._lock = lock
+
+        self._thread_safe_methods = {
+            "__call__",
+            "encode",
+            "decode",
+            "tokenize",
+            "encode_plus",
+            "batch_encode_plus",
+            "batch_decode",
+        }
+
+    def __getattr__(self, name):
+        attr = getattr(self._tokenizer, name)
+
+        if name in self._thread_safe_methods and callable(attr):
+
+            def wrapped_method(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped_method
+
+        return attr
+
+    def __call__(self, *args, **kwargs):
+        with self._lock:
+            return self._tokenizer(*args, **kwargs)
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._tokenizer, name, value)
+
+    def __dir__(self):
+        return dir(self._tokenizer)
+
+
+class ThreadSafeVAEWrapper:
+    def __init__(self, vae, lock):
+        self._vae = vae
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._vae, name)
+        if name in {"decode", "encode", "forward"} and callable(attr):
+
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._vae, name, value)
+
+
+class ThreadSafeImageProcessorWrapper:
+    def __init__(self, proc, lock):
+        self._proc = proc
+        self._lock = lock
+
+    def __getattr__(self, name):
+        attr = getattr(self._proc, name)
+        if name in {"postprocess", "preprocess"} and callable(attr):
+
+            def wrapped(*args, **kwargs):
+                with self._lock:
+                    return attr(*args, **kwargs)
+
+            return wrapped
+        return attr
+
+    def __setattr__(self, name, value):
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+        else:
+            setattr(self._proc, name, value)
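A quick way to check that the wrappers actually serialize access is to drive one from several threads with a deliberately slow stand-in component; the counter below survives all eight increments only because each encode() call holds the lock. Everything here is illustrative, not part of the commit:

import threading
import time

from utils.wrappers import ThreadSafeTokenizerWrapper

class SlowTokenizer:
    def __init__(self):
        self.calls = 0

    def encode(self, text):
        current = self.calls
        time.sleep(0.01)  # widen the race window
        self.calls = current + 1
        return [len(text)]

lock = threading.Lock()
tok = ThreadSafeTokenizerWrapper(SlowTokenizer(), lock)

threads = [threading.Thread(target=tok.encode, args=("prompt",)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# With the lock, encode() runs one call at a time, so no increment is lost.
print(tok.calls)  # 8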