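"""Request-scoped pipeline wrapper for the async example server.

``RequestScopedPipeline`` keeps one shared base pipeline (model weights are never
duplicated) and, for every request, builds a shallow copy with its own scheduler
and per-request mutable state, while routing tokenizer calls through a lock so
concurrent requests do not race on shared tokenizers.
"""
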
import copy
import threading
import types
from typing import Any, Iterable, List, Optional

import torch

from diffusers.utils import logging

from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps


logger = logging.get_logger(__name__)


def safe_tokenize(tokenizer, *args, lock, **kwargs):
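    """Call ``tokenizer`` while holding ``lock`` so concurrent requests never tokenize at the same time."""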
    with lock:
        return tokenizer(*args, **kwargs)


class RequestScopedPipeline:
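    """Runs each request on its own lightweight view of a shared diffusers pipeline.

    The heavy modules (UNet/transformer, VAE, text encoders) stay shared across requests;
    per-request state (scheduler, progress bars, RNG state, small tensors and containers)
    is cloned onto a shallow copy before every ``generate()`` call, and tokenizer calls are
    serialized behind a shared lock.
    """
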
    DEFAULT_MUTABLE_ATTRS = [
        "_all_hooks",
        "_offload_device",
        "_progress_bar_config",
        "_progress_bar",
        "_rng_state",
        "_last_seed",
        "latents",
    ]

    def __init__(
        self,
        pipeline: Any,
        mutable_attrs: Optional[Iterable[str]] = None,
        auto_detect_mutables: bool = True,
        tensor_numel_threshold: int = 1_000_000,
        tokenizer_lock: Optional[threading.Lock] = None,
        wrap_scheduler: bool = True,
    ):
        self._base = pipeline
        self.unet = getattr(pipeline, "unet", None)
        self.vae = getattr(pipeline, "vae", None)
        self.text_encoder = getattr(pipeline, "text_encoder", None)
        self.components = getattr(pipeline, "components", None)

        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

        self._auto_detect_mutables = bool(auto_detect_mutables)
        self._tensor_numel_threshold = int(tensor_numel_threshold)

        self._auto_detected_attrs: List[str] = []

    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
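        """Return a per-request scheduler clone (``clone_for_request`` if available, else a deepcopy),
        or ``None`` when the base pipeline has no scheduler."""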
        base_sched = getattr(self._base, "scheduler", None)
        if base_sched is None:
            return None

        if not isinstance(base_sched, BaseAsyncScheduler):
            wrapped_scheduler = BaseAsyncScheduler(base_sched)
        else:
            wrapped_scheduler = base_sched

        try:
            return wrapped_scheduler.clone_for_request(
                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
            )
        except Exception as e:
            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
            try:
                return copy.deepcopy(wrapped_scheduler)
            except Exception as e:
                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
                return wrapped_scheduler

    def _autodetect_mutables(self, max_attrs: int = 40):
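        """Scan the base pipeline for additional per-request state to clone: containers and small
        tensors (``numel() <= tensor_numel_threshold``). The result is computed once and cached."""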
        if not self._auto_detect_mutables:
            return []

        if self._auto_detected_attrs:
            return self._auto_detected_attrs

        candidates: List[str] = []
        seen = set()
        for name in dir(self._base):
            if name.startswith("__"):
                continue
            if name in self._mutable_attrs:
                continue
            if name in ("to", "save_pretrained", "from_pretrained"):
                continue
            try:
                val = getattr(self._base, name)
            except Exception:
                continue

            # skip callables and modules
            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                continue

            # containers -> candidate
            if isinstance(val, (dict, list, set, tuple, bytearray)):
                candidates.append(name)
                seen.add(name)
            else:
                # try Tensor detection
                try:
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            candidates.append(name)
                            seen.add(name)
                        else:
                            logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
                except Exception:
                    continue

            if len(candidates) >= max_attrs:
                break

        self._auto_detected_attrs = candidates
        logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
        return self._auto_detected_attrs

    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
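        """Return True if ``attr_name`` is a read-only property (a ``property`` with no setter) on ``base_obj``'s class."""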
        try:
            cls = type(base_obj)
            descriptor = getattr(cls, attr_name, None)
            if isinstance(descriptor, property):
                return descriptor.fset is None
            if descriptor is not None and not hasattr(descriptor, "__set__"):
                return False
        except Exception:
            pass
        return False

    def _clone_mutable_attrs(self, base, local):
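        """Copy per-request mutable attributes from ``base`` onto ``local`` so requests do not share them.

        Dicts/lists/sets are shallow-copied, small tensors are cloned, large tensors are kept as
        references, and read-only or excluded attributes are skipped.
        """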
        attrs_to_clone = list(self._mutable_attrs)
        attrs_to_clone.extend(self._autodetect_mutables())

        EXCLUDE_ATTRS = {
            "components",
        }

        for attr in attrs_to_clone:
            if attr in EXCLUDE_ATTRS:
                logger.debug(f"Skipping excluded attr '{attr}'")
                continue
            if not hasattr(base, attr):
                continue
            if self._is_readonly_property(base, attr):
                logger.debug(f"Skipping read-only property '{attr}'")
                continue

            try:
                val = getattr(base, attr)
            except Exception as e:
                logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
                continue

            try:
                if isinstance(val, dict):
                    setattr(local, attr, dict(val))
                elif isinstance(val, (list, tuple, set)):
                    setattr(local, attr, list(val))
                elif isinstance(val, bytearray):
                    setattr(local, attr, bytearray(val))
                else:
                    # small tensors or atomic values
                    if isinstance(val, torch.Tensor):
                        if val.numel() <= self._tensor_numel_threshold:
                            setattr(local, attr, val.clone())
                        else:
                            # don't clone big tensors, keep a reference
                            setattr(local, attr, val)
                    else:
                        try:
                            setattr(local, attr, copy.copy(val))
                        except Exception:
                            setattr(local, attr, val)
            except (AttributeError, TypeError) as e:
                logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
                continue
            except Exception as e:
                logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
                continue

    def _is_tokenizer_component(self, component) -> bool:
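        """Heuristically decide whether ``component`` is a tokenizer: it must expose tokenizer-like
        methods and either carry 'tokenizer' in its class name or typical tokenizer attributes."""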
        if component is None:
            return False

        tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)

        class_name = component.__class__.__name__.lower()
        has_tokenizer_in_name = "tokenizer" in class_name

        tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)

        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)

    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
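        """Run one request on a request-local shallow copy of the base pipeline.

        Steps: build a per-request scheduler, shallow-copy the pipeline, configure timesteps,
        clone mutable attributes, wrap tokenizers with the shared lock, call the pipeline
        (inside ``model_cpu_offload_context`` when the pipeline provides one), and finally
        restore the original tokenizers.
        """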
        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

        try:
            local_pipe = copy.copy(self._base)
        except Exception as e:
            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
            local_pipe = copy.deepcopy(self._base)

        if local_scheduler is not None:
            try:
                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
                    local_scheduler.scheduler,
                    num_inference_steps=num_inference_steps,
                    device=device,
                    return_scheduler=True,
                    **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
                )

                final_scheduler = BaseAsyncScheduler(configured_scheduler)
                setattr(local_pipe, "scheduler", final_scheduler)
            except Exception:
                logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

        self._clone_mutable_attrs(self._base, local_pipe)

        # wrap tokenizers on the local pipe with the lock wrapper
        tokenizer_wrappers = {}  # name -> original_tokenizer
        try:
            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
            for name in dir(local_pipe):
                if "tokenizer" in name and not name.startswith("_"):
                    tok = getattr(local_pipe, name, None)
                    if tok is not None and self._is_tokenizer_component(tok):
                        tokenizer_wrappers[name] = tok
                        setattr(
                            local_pipe,
                            name,
                            lambda *args, tok=tok, **kwargs: safe_tokenize(
                                tok, *args, lock=self._tokenizer_lock, **kwargs
                            ),
                        )

            # b) wrap tokenizers in the components dict
            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                for key, val in local_pipe.components.items():
                    if val is None:
                        continue

                    if self._is_tokenizer_component(val):
                        tokenizer_wrappers[f"components[{key}]"] = val
                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                        )

        except Exception as e:
            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

        result = None
        cm = getattr(local_pipe, "model_cpu_offload_context", None)
        try:
            if callable(cm):
                try:
                    with cm():
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                except TypeError:
                    # cm might be a context manager instance rather than a callable
                    try:
                        with cm:
                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                    except Exception as e:
                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            else:
                # no offload context available; call the pipeline directly
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

            return result

        finally:
            # restore the original tokenizers on the local pipe
            try:
                for name, tok in tokenizer_wrappers.items():
                    if name.startswith("components["):
                        key = name[len("components[") : -1]
                        local_pipe.components[key] = tok
                    else:
                        setattr(local_pipe, name, tok)
            except Exception as e:
                logger.debug(f"Error restoring wrapped tokenizers: {e}")
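

# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept as comments, not executed): one possible way
# to drive RequestScopedPipeline from several worker threads. The model id,
# dtype, device, and prompts below are assumptions for the example; swap in
# whatever pipeline the server actually loads.
#
#   import threading
#   import torch
#   from diffusers import StableDiffusion3Pipeline
#
#   base = StableDiffusion3Pipeline.from_pretrained(
#       "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   request_pipe = RequestScopedPipeline(base)
#
#   def handle_request(prompt: str):
#       # each call gets its own scheduler and mutable state; tokenizer calls are locked
#       return request_pipe.generate(prompt=prompt, num_inference_steps=30, device="cuda")
#
#   threads = [threading.Thread(target=handle_request, args=(p,)) for p in ["a cat", "a dog"]]
#   for t in threads:
#       t.start()
#   for t in threads:
#       t.join()
# ---------------------------------------------------------------------------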