mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00

commit: refactor
src/diffusers/__init__.py

@@ -215,6 +215,7 @@ else:
             "MultiAdapter",
             "MultiControlNetModel",
             "OmniGenTransformer2DModel",
+            "ParallelConfig",
             "PixArtTransformer2DModel",
             "PriorTransformer",
             "QwenImageTransformer2DModel",
@@ -243,6 +244,7 @@ else:
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
             "attention_backend",
+            "enable_parallelism",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -879,6 +881,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             MultiAdapter,
             MultiControlNetModel,
             OmniGenTransformer2DModel,
+            ParallelConfig,
             PixArtTransformer2DModel,
             PriorTransformer,
             QwenImageTransformer2DModel,
@@ -906,6 +909,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             WanTransformer3DModel,
             WanVACETransformer3DModel,
             attention_backend,
+            enable_parallelism,
         )
         from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
         from .optimization import (
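Net effect of the four hunks above: `ParallelConfig` and `enable_parallelism` join the package-root API, in both the lazy `_import_structure` tables and the eager `TYPE_CHECKING` branch. A minimal sketch of what this enables (assuming a torch install, since both exports sit behind the `is_torch_available()` gate):

    # Both names now resolve at the package root via the lazy-import machinery.
    from diffusers import ParallelConfig, enable_parallelism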
src/diffusers/models/__init__.py

@@ -25,7 +25,7 @@ from ..utils import (
 _import_structure = {}
 
 if is_torch_available():
-    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
+    _import_structure["_modeling_parallel"] = ["ParallelConfig", "enable_parallelism"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -115,7 +115,7 @@ if is_flax_available():
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ._modeling_parallel import ParallelConfig
+        from ._modeling_parallel import ParallelConfig, enable_parallelism
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
src/diffusers/models/_modeling_parallel.py

@@ -15,14 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 from dataclasses import dataclass
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 
 from ..utils import get_logger
 
 
+if TYPE_CHECKING:
+    from ..pipelines.pipeline_utils import DiffusionPipeline
+    from .modeling_utils import ModelMixin
+
+
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -117,3 +123,53 @@ ContextParallelOutputType = Union[
 # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
 # the module should be split/gathered across context parallel region.
 ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+_ENABLE_PARALLELISM_WARN_ONCE = False
+
+
+@contextlib.contextmanager
+def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
+    from diffusers import DiffusionPipeline, ModelMixin
+
+    from .attention_dispatch import _AttentionBackendRegistry
+
+    global _ENABLE_PARALLELISM_WARN_ONCE
+    if not _ENABLE_PARALLELISM_WARN_ONCE:
+        logger.warning(
+            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future."
+        )
+        _ENABLE_PARALLELISM_WARN_ONCE = True
+
+    if isinstance(model_or_pipeline, DiffusionPipeline):
+        parallelized_components = [
+            (name, component)
+            for name, component in model_or_pipeline.components.items()
+            if getattr(component, "_internal_parallel_config", None) is not None
+        ]
+        if len(parallelized_components) > 1:
+            raise ValueError(
+                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run "
+                "different stages of the pipeline separately with `enable_parallelism` on each component manually."
+            )
+        if len(parallelized_components) == 0:
+            raise ValueError(
+                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized."
+            )
+        _, model_or_pipeline = parallelized_components[0]
+    elif isinstance(model_or_pipeline, ModelMixin):
+        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
+            raise ValueError(
+                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
+            )
+    else:
+        raise TypeError(
+            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
+        )
+
+    old_parallel_config = _AttentionBackendRegistry._parallel_config
+    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
+
+    yield
+
+    _AttentionBackendRegistry._parallel_config = old_parallel_config
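For orientation, a usage sketch of the new context manager. Everything outside the `with` statement — the checkpoint id, the `ParallelConfig` construction, the `transformer` component name, and the distributed setup — is an illustrative assumption, not something this commit specifies:

    import torch
    from diffusers import DiffusionPipeline, ParallelConfig, enable_parallelism

    # Hypothetical setup: checkpoint id and config arguments are placeholders.
    pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.transformer.parallelize(config=ParallelConfig())  # assumed default-constructible

    # enable_parallelism locates the single parallelized component of the pipeline,
    # installs its _internal_parallel_config on _AttentionBackendRegistry for the
    # duration of the block, and restores the previous config on exit.
    with enable_parallelism(pipe):
        image = pipe("a prompt").images[0]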
src/diffusers/models/modeling_utils.py

@@ -249,6 +249,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
     _repeated_blocks = []
+    _internal_parallel_config = None
     _cp_plan = None
 
     def __init__(self):
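The class-level `None` default matters because `enable_parallelism` probes components with `getattr(component, "_internal_parallel_config", None)`, so un-parallelized models answer that probe uniformly. The same predicate, extracted for illustration (the helper name is ours, not the library's):

    # Illustrative helper, not part of diffusers: mirrors the getattr probe
    # enable_parallelism uses to detect parallelized components.
    def _is_parallelized(component) -> bool:
        return getattr(component, "_internal_parallel_config", None) is not None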
@@ -1480,10 +1481,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
             )
 
-    @contextmanager
     def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
-        from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
-        from .attention_dispatch import _AttentionBackendRegistry
+        from ..hooks.context_parallel import apply_context_parallel
 
         logger.warning(
             "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
@@ -1530,11 +1529,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         cp_plan = cp_plan if cp_plan is not None else self._cp_plan
 
         apply_context_parallel(self, parallel_config, cp_plan)
-        _AttentionBackendRegistry._parallel_config = parallel_config
-
-        yield
-
-        remove_context_parallel(self, cp_plan)
+        self._internal_parallel_config = parallel_config
 
     @classmethod
     def _load_pretrained_model(
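Read together, these two hunks change `parallelize` from a context manager into a plain method: the hooks installed by `apply_context_parallel` now stay in place, the config is recorded on the model as `_internal_parallel_config`, and activation moves to `enable_parallelism`. A sketch of the resulting calling-convention change (`model`, `config`, and `x` are assumed to be set up elsewhere):

    # Before this commit: parallelize was a generator-based context manager; the
    # parallel config was only live inside the with-block, and the hooks were
    # torn down on exit via remove_context_parallel.
    with model.parallelize(config=config):
        out = model(x)

    # After this commit: parallelize applies the context-parallel hooks once and
    # stores the config; enable_parallelism toggles it on around the forward pass.
    model.parallelize(config=config)
    with enable_parallelism(model):
        out = model(x)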