1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

refator based on dhruv's feedbacks

This commit is contained in:
yiyixuxu
2025-06-18 10:11:22 +02:00
parent f16e9c7807
commit cb6d5fed19
3 changed files with 146 additions and 121 deletions

View File

@@ -23,7 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineMixin",
"ModularPipelineBlocks",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
@@ -53,7 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineMixin,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,

View File

@@ -243,8 +243,7 @@ class BlockState:
return f"BlockState(\n{attributes}\n)"
class ModularPipelineMixin(ConfigMixin):
class ModularPipelineBlocks(ConfigMixin):
"""
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
"""
@@ -305,13 +304,10 @@ class ModularPipelineMixin(ConfigMixin):
}
return block_cls(**block_kwargs)
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
"""
create a ModularLoader, optionally accept modular_repo to load from hub.
"""
# Import components loader (it is model-specific class)
loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
diffusers_module = importlib.import_module("diffusers")
loader_class = getattr(diffusers_module, loader_class_name)
@@ -322,98 +318,12 @@ class ModularPipelineMixin(ConfigMixin):
# Create the loader with the updated specs
specs = component_specs + config_specs
self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
return modular_pipeline
@property
def default_call_parameters(self) -> Dict[str, Any]:
params = {}
for input_param in self.inputs:
params[input_param.name] = input_param.default
return params
def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
"""
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
"""
if state is None:
state = PipelineState()
if not hasattr(self, "loader"):
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
self.loader = None
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
for expected_input_param in self.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediates_inputs:
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
else:
state.add_input(name, passed_kwargs[name], kwargs_type)
elif name not in state.inputs:
state.add_input(name, default, kwargs_type)
for expected_intermediate_param in self.intermediates_inputs:
name = expected_intermediate_param.name
kwargs_type = expected_intermediate_param.kwargs_type
if name in passed_kwargs:
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
# Warn about unexpected inputs
if len(passed_kwargs) > 0:
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
# Run the pipeline
with torch.no_grad():
try:
pipeline, state = self(self.loader, state)
except Exception:
error_msg = f"Error in block: ({self.__class__.__name__}):\n"
logger.error(error_msg)
raise
if output is None:
return state
elif isinstance(output, str):
return state.get_intermediate(output)
elif isinstance(output, (list, tuple)):
return state.get_intermediates(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")
@torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
class PipelineBlock(ModularPipelineMixin):
class PipelineBlock(ModularPipelineBlocks):
model_name = None
@@ -680,7 +590,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
return list(combined_dict.values())
class AutoPipelineBlocks(ModularPipelineMixin):
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
A class that automatically selects a block to run based on the inputs.
@@ -969,7 +879,8 @@ class AutoPipelineBlocks(ModularPipelineMixin):
expected_configs=self.expected_configs
)
class SequentialPipelineBlocks(ModularPipelineMixin):
class SequentialPipelineBlocks(ModularPipelineBlocks):
"""
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
"""
@@ -1009,15 +920,24 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
Args:
blocks_dict: Dictionary mapping block names to block instances
blocks_dict: Dictionary mapping block names to block classes or instances
Returns:
A new SequentialPipelineBlocks instance
"""
instance = cls()
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
instance.block_names = list(blocks_dict.keys())
instance.blocks = blocks_dict
# Create instances if classes are provided
blocks = {}
for name, block in blocks_dict.items():
if inspect.isclass(block):
blocks[name] = block()
else:
blocks[name] = block
instance.block_classes = [block.__class__ for block in blocks.values()]
instance.block_names = list(blocks.keys())
instance.blocks = blocks
return instance
def __init__(self):
@@ -1330,7 +1250,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
)
#YiYi TODO: __repr__
class LoopSequentialPipelineBlocks(ModularPipelineMixin):
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
"""
@@ -1694,7 +1614,24 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
return result
@torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
# YiYi TODO:
@@ -1889,19 +1826,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
return torch.device(module._hf_hook.execution_device)
return self.device
@property
def device(self) -> torch.device:
r"""
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)]
for module in modules:
return module.device
return torch.device("cpu")
@property
def dtype(self) -> torch.dtype:
@@ -2197,4 +2121,105 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
name=name,
type_hint=type_hint,
**spec_dict,
)
)
class ModularPipeline:
"""
Base class for all Modular pipelines.
Args:
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
loader: ModularLoader, the loader to be used in the pipeline
"""
def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader):
self.blocks = blocks
self.loader = loader
@property
def default_call_parameters(self) -> Dict[str, Any]:
params = {}
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
return params
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
"""
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
"""
if state is None:
state = PipelineState()
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediates_inputs:
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
else:
state.add_input(name, passed_kwargs[name], kwargs_type)
elif name not in state.inputs:
state.add_input(name, default, kwargs_type)
for expected_intermediate_param in self.blocks.intermediates_inputs:
name = expected_intermediate_param.name
kwargs_type = expected_intermediate_param.kwargs_type
if name in passed_kwargs:
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
# Warn about unexpected inputs
if len(passed_kwargs) > 0:
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
# Run the pipeline
with torch.no_grad():
try:
pipeline, state = self.blocks(self.loader, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise
if output is None:
return state
elif isinstance(output, str):
return state.get_intermediate(output)
elif isinstance(output, (list, tuple)):
return state.get_intermediates(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")
def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
self.loader.load(component_names=component_names, **kwargs)
def update_components(self, **kwargs):
self.loader.update(**kwargs)
def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs)
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs)
return ModularPipeline(blocks=blocks, loader=loader)
def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs):
self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def doc(self):
return self.blocks.doc

View File

@@ -1,5 +1,5 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
@@ -202,7 +202,7 @@ class ModularNode(ConfigMixin):
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):