[core] gracefully error out when attn-backend x cp combo isn't supported. (#12832)
* gracefully error out when attn-backend x cp combo isn't supported.
* Revert "gracefully error out when attn-backend x cp combo isn't supported."
This reverts commit c8abb5d7c0.
* gracefully error out when attn-backend x cp combo isn't supported.
* up
* address PR feedback.
* up
* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* dot.
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
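For context, a minimal usage sketch (not part of the commit) of the behavior this change introduces, using names that appear in the diff (`set_attention_backend`, `enable_parallelism`); the model class, checkpoint, and backend name below are illustrative assumptions:

import torch
from diffusers import AutoModel  # AutoModel is an assumption; any ModelMixin subclass applies

# Illustrative checkpoint; any transformer model with attention processors will do.
model = AutoModel.from_pretrained("org/repo", subfolder="transformer", torch_dtype=torch.bfloat16)

# Suppose a parallel config has already been attached to the attention processors,
# e.g. by calling `model.enable_parallelism(...)` in a multi-rank run.
try:
    # With this PR, requesting a backend that the registry does not mark as
    # context-parallel capable fails fast with a ValueError listing the
    # compatible backends, instead of erroring later inside attention dispatch.
    model.set_attention_backend("sage")  # backend name is illustrative
except ValueError as err:
    print(err)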
@@ -235,6 +235,10 @@ class _AttentionBackendRegistry:
     def get_active_backend(cls):
         return cls._active_backend, cls._backends[cls._active_backend]
 
+    @classmethod
+    def set_active_backend(cls, backend: str):
+        cls._active_backend = backend
+
     @classmethod
     def list_backends(cls):
         return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
     _maybe_download_kernel_for_backend(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
-    _AttentionBackendRegistry._active_backend = backend
+    _AttentionBackendRegistry.set_active_backend(backend)
 
     try:
         yield
     finally:
-        _AttentionBackendRegistry._active_backend = old_backend
+        _AttentionBackendRegistry.set_active_backend(old_backend)
 
 
 def dispatch_attention_fn(
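For reference, a hedged sketch of how the context manager touched above is exercised; `attention_backend` and `dispatch_attention_fn` are defined in `attention_dispatch` per this diff, though importing them directly is an assumption and the "native" backend name is illustrative:

import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

query = torch.randn(1, 8, 128, 64)
key, value = torch.randn_like(query), torch.randn_like(query)

# Entering the block now swaps the registry's active backend through
# `set_active_backend(...)`; the `finally` clause restores the previous one on exit.
with attention_backend("native"):
    out = dispatch_attention_fn(query, key, value)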
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
             check(**kwargs)
 
     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
     return backend_fn(**kwargs)
 
 
@@ -599,6 +599,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention import AttentionModuleMixin
         from .attention_dispatch import (
             AttentionBackendName,
+            _AttentionBackendRegistry,
             _check_attention_backend_requirements,
             _maybe_download_kernel_for_backend,
         )
@@ -607,6 +608,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention_processor import Attention, MochiAttention
 
         logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        parallel_config_set = False
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if getattr(processor, "_parallel_config", None) is not None:
+                parallel_config_set = True
+                break
 
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
@@ -614,10 +625,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend)
+        if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+            compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+            raise ValueError(
+                f"Context parallelism is enabled but current attention backend '{backend.value}' "
+                f"does not support context parallelism. "
+                f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
+            )
+
         _check_attention_backend_requirements(backend)
         _maybe_download_kernel_for_backend(backend)
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -626,6 +644,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._attention_backend = backend
 
+        # Important to set the active backend so that it propagates gracefully throughout.
+        _AttentionBackendRegistry.set_active_backend(backend)
+
     def reset_attention_backend(self) -> None:
         """
         Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1538,7 +1559,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                         f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                         f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
-                        f"calling `enable_parallelism()`."
+                        f"calling `model.enable_parallelism()`."
                     )
 
             # All modules use the same attention processor and backend. We don't need to