commit 256d5a95be
parent c777184215
Author: Aryan
Date: 2025-08-14 06:44:31 +02:00
4 changed files with 66 additions and 11 deletions

View File

@@ -215,6 +215,7 @@ else:
             "MultiAdapter",
             "MultiControlNetModel",
             "OmniGenTransformer2DModel",
+            "ParallelConfig",
             "PixArtTransformer2DModel",
             "PriorTransformer",
             "QwenImageTransformer2DModel",
@@ -243,6 +244,7 @@ else:
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
             "attention_backend",
+            "enable_parallelism",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -879,6 +881,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             MultiAdapter,
             MultiControlNetModel,
             OmniGenTransformer2DModel,
+            ParallelConfig,
             PixArtTransformer2DModel,
             PriorTransformer,
             QwenImageTransformer2DModel,
@@ -906,6 +909,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             WanTransformer3DModel,
             WanVACETransformer3DModel,
             attention_backend,
+            enable_parallelism,
         )
         from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
         from .optimization import (

View File

@@ -25,7 +25,7 @@ from ..utils import (
 _import_structure = {}

 if is_torch_available():
-    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
+    _import_structure["_modeling_parallel"] = ["ParallelConfig", "enable_parallelism"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -115,7 +115,7 @@ if is_flax_available():

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ._modeling_parallel import ParallelConfig
+        from ._modeling_parallel import ParallelConfig, enable_parallelism
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
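
Both lazy-import registration sites above must stay in sync: the `_import_structure` dict drives diffusers' `_LazyModule` loader in `diffusers.utils`, while the `TYPE_CHECKING` branch gives static analyzers real imports. A toy sketch of the pattern (illustrative names, not the actual implementation):

# Toy sketch of the lazy-import pattern: each public name maps to the
# submodule that defines it, and that submodule is imported only on first
# attribute access.
import importlib


class _ToyLazyModule:
    def __init__(self, pkg_name, import_structure):
        self._pkg_name = pkg_name
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        module_name = self._attr_to_module.get(attr)
        if module_name is None:
            raise AttributeError(f"module {self._pkg_name!r} has no attribute {attr!r}")
        submodule = importlib.import_module(f".{module_name}", self._pkg_name)
        return getattr(submodule, attr)


# With this commit, "enable_parallelism" is registered under
# "_modeling_parallel", so it resolves on first use without importing
# torch-heavy modules at package import time.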

View File

@@ -15,14 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 from dataclasses import dataclass
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch

 from ..utils import get_logger

+
+if TYPE_CHECKING:
+    from ..pipelines.pipeline_utils import DiffusionPipeline
+    from .modeling_utils import ModelMixin
+

 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -117,3 +123,53 @@ ContextParallelOutputType = Union[
 # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
 # the module should be split/gathered across context parallel region.
 ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+_ENABLE_PARALLELISM_WARN_ONCE = False
+
+
+@contextlib.contextmanager
+def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
+    from diffusers import DiffusionPipeline, ModelMixin
+
+    from .attention_dispatch import _AttentionBackendRegistry
+
+    global _ENABLE_PARALLELISM_WARN_ONCE
+    if not _ENABLE_PARALLELISM_WARN_ONCE:
+        logger.warning(
+            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future."
+        )
+        _ENABLE_PARALLELISM_WARN_ONCE = True
+
+    if isinstance(model_or_pipeline, DiffusionPipeline):
+        parallelized_components = [
+            (name, component)
+            for name, component in model_or_pipeline.components.items()
+            if getattr(component, "_internal_parallel_config", None) is not None
+        ]
+        if len(parallelized_components) > 1:
+            raise ValueError(
+                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run "
+                "different stages of the pipeline separately with `enable_parallelism` on each component manually."
+            )
+        if len(parallelized_components) == 0:
+            raise ValueError(
+                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized."
+            )
+        _, model_or_pipeline = parallelized_components[0]
+    elif isinstance(model_or_pipeline, ModelMixin):
+        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
+            raise ValueError(
+                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
+            )
+    else:
+        raise TypeError(
+            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
+        )
+
+    old_parallel_config = _AttentionBackendRegistry._parallel_config
+    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
+
+    yield
+
+    _AttentionBackendRegistry._parallel_config = old_parallel_config
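
For orientation, a usage sketch of the new context manager follows. This is not part of the commit: the checkpoint id is a placeholder, the `ring_degree` field of `ParallelConfig` is an assumption for illustration, and a multi-GPU `torchrun` launch is presumed.

# Hypothetical usage sketch (not from this commit). Assumes `torchrun` with
# one process per GPU; the checkpoint id and ParallelConfig fields are
# illustrative placeholders.
import torch
import torch.distributed as dist

from diffusers import DiffusionPipeline, ParallelConfig, enable_parallelism

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

pipe = DiffusionPipeline.from_pretrained("org/model-id", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Parallelize exactly one component; `enable_parallelism` on a pipeline
# raises if zero or several components carry a parallel config.
pipe.transformer.parallelize(config=ParallelConfig(ring_degree=dist.get_world_size()))

# Inside the block, the component's `_internal_parallel_config` is installed
# on `_AttentionBackendRegistry`; the previous config is restored on exit.
with enable_parallelism(pipe):
    image = pipe("a photo of an astronaut riding a horse").images[0]

if dist.get_rank() == 0:
    image.save("output.png")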

View File

@@ -249,6 +249,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
     _repeated_blocks = []
+    _internal_parallel_config = None
     _cp_plan = None

     def __init__(self):
@@ -1480,10 +1481,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
             )

-    @contextmanager
     def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
-        from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
-        from .attention_dispatch import _AttentionBackendRegistry
+        from ..hooks.context_parallel import apply_context_parallel

         logger.warning(
             "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
@@ -1530,11 +1529,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

         cp_plan = cp_plan if cp_plan is not None else self._cp_plan
         apply_context_parallel(self, parallel_config, cp_plan)
-        _AttentionBackendRegistry._parallel_config = parallel_config
-
-        yield
-
-        remove_context_parallel(self, cp_plan)
+        self._internal_parallel_config = parallel_config

     @classmethod
     def _load_pretrained_model(
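
With this change, `parallelize` is a plain method rather than a context manager: it applies the context-parallel hooks once, records the config as `_internal_parallel_config`, and leaves activation to `enable_parallelism`. A short sketch of the resulting guard rails, using a hypothetical model id (the error messages come from the code added above):

# Hypothetical sketch; the model id is a placeholder. `enable_parallelism`
# validates its argument before touching the attention registry.
from diffusers import AutoModel, enable_parallelism

model = AutoModel.from_pretrained("org/model-id", subfolder="transformer")

# A model that was never `.parallelize()`-d is rejected:
try:
    with enable_parallelism(model):
        pass
except ValueError as err:
    print(err)  # "The model is not parallelized. ..."

# So is anything that is neither a DiffusionPipeline nor a ModelMixin:
try:
    with enable_parallelism(object()):
        pass
except TypeError as err:
    print(err)  # "Expected a `DiffusionPipeline` or `ModelMixin` instance, ..."

Because the config now persists on the model, repeated inference calls can reuse the same `enable_parallelism` block without re-applying hooks.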