diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 0053074bad..bedfe04ba4 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -215,6 +215,7 @@ else:
             "MultiAdapter",
             "MultiControlNetModel",
             "OmniGenTransformer2DModel",
+            "ParallelConfig",
             "PixArtTransformer2DModel",
             "PriorTransformer",
             "QwenImageTransformer2DModel",
@@ -243,6 +244,7 @@ else:
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
             "attention_backend",
+            "enable_parallelism",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -879,6 +881,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MultiAdapter,
         MultiControlNetModel,
         OmniGenTransformer2DModel,
+        ParallelConfig,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageTransformer2DModel,
@@ -906,6 +909,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         WanTransformer3DModel,
         WanVACETransformer3DModel,
         attention_backend,
+        enable_parallelism,
     )
     from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
     from .optimization import (
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index ca57ce27d4..c8289eed6d 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -25,7 +25,7 @@ from ..utils import (
 _import_structure = {}
 
 if is_torch_available():
-    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
+    _import_structure["_modeling_parallel"] = ["ParallelConfig", "enable_parallelism"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -115,7 +115,7 @@ if is_flax_available():
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ._modeling_parallel import ParallelConfig
+        from ._modeling_parallel import ParallelConfig, enable_parallelism
         from .adapter import MultiAdapter, T2IAdapter
        from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 2a4e62a6e5..495f3b10ce 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -15,14 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 from dataclasses import dataclass
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 
 from ..utils import get_logger
 
 
+if TYPE_CHECKING:
+    from ..pipelines.pipeline_utils import DiffusionPipeline
+    from .modeling_utils import ModelMixin
+
+
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -117,3 +123,60 @@ ContextParallelOutputType = Union[
 # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
 # the module should be split/gathered across context parallel region.
 ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+_ENABLE_PARALLELISM_WARN_ONCE = False
+
+
+@contextlib.contextmanager
+def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
+    """Temporarily route attention through the parallel config of a parallelized model.
+
+    Accepts either a `ModelMixin` that was parallelized with `.parallelize()`, or a
+    `DiffusionPipeline` with exactly one such parallelized component.
+    """
+    from diffusers import DiffusionPipeline, ModelMixin
+
+    from .attention_dispatch import _AttentionBackendRegistry
+
+    global _ENABLE_PARALLELISM_WARN_ONCE
+    if not _ENABLE_PARALLELISM_WARN_ONCE:
+        logger.warning(
+            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future."
+        )
+        _ENABLE_PARALLELISM_WARN_ONCE = True
+
+    if isinstance(model_or_pipeline, DiffusionPipeline):
+        parallelized_components = [
+            (name, component)
+            for name, component in model_or_pipeline.components.items()
+            if getattr(component, "_internal_parallel_config", None) is not None
+        ]
+        if len(parallelized_components) > 1:
+            raise ValueError(
+                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run "
+                "different stages of the pipeline separately with `enable_parallelism` on each component manually."
+            )
+        if len(parallelized_components) == 0:
+            raise ValueError(
+                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized."
+            )
+        _, model_or_pipeline = parallelized_components[0]
+    elif isinstance(model_or_pipeline, ModelMixin):
+        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
+            raise ValueError(
+                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
+            )
+    else:
+        raise TypeError(
+            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
+        )
+
+    old_parallel_config = _AttentionBackendRegistry._parallel_config
+    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
+
+    try:
+        yield
+    finally:
+        # Restore the previous parallel config even if the caller's block raises.
+        _AttentionBackendRegistry._parallel_config = old_parallel_config
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 9780002c5b..89c1e71a68 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -249,6 +249,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
     _repeated_blocks = []
+    _internal_parallel_config = None
     _cp_plan = None
 
     def __init__(self):
@@ -1480,10 +1481,8 @@
                 f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
             )
 
-    @contextmanager
     def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
-        from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
-        from .attention_dispatch import _AttentionBackendRegistry
+        from ..hooks.context_parallel import apply_context_parallel
 
         logger.warning(
             "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
@@ -1530,11 +1529,7 @@
         cp_plan = cp_plan if cp_plan is not None else self._cp_plan
 
         apply_context_parallel(self, parallel_config, cp_plan)
-        _AttentionBackendRegistry._parallel_config = parallel_config
-
-        yield
-
-        remove_context_parallel(self, cp_plan)
+        self._internal_parallel_config = parallel_config
 
     @classmethod
     def _load_pretrained_model(