From cb6d5fed19ce4672857d6dfbf95ba2848feea5b5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 18 Jun 2025 10:11:22 +0200 Subject: [PATCH] refator based on dhruv's feedbacks --- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/modular_pipeline.py | 259 ++++++++++-------- src/diffusers/modular_pipelines/node_utils.py | 4 +- 3 files changed, 146 insertions(+), 121 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index cb2ed78ce3..8a23219761 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -23,7 +23,7 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["modular_pipeline"] = [ - "ModularPipelineMixin", + "ModularPipelineBlocks", "PipelineBlock", "AutoPipelineBlocks", "SequentialPipelineBlocks", @@ -53,7 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: BlockState, LoopSequentialPipelineBlocks, ModularLoader, - ModularPipelineMixin, + ModularPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 3136c3bb11..5a93a29951 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -243,8 +243,7 @@ class BlockState: return f"BlockState(\n{attributes}\n)" - -class ModularPipelineMixin(ConfigMixin): +class ModularPipelineBlocks(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ @@ -305,13 +304,10 @@ class ModularPipelineMixin(ConfigMixin): } return block_cls(**block_kwargs) - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. """ - - # Import components loader (it is model-specific class) loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) @@ -322,98 +318,12 @@ class ModularPipelineMixin(ConfigMixin): # Create the loader with the updated specs specs = component_specs + config_specs - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + modular_pipeline = ModularPipeline(blocks=self, loader=loader) + return modular_pipeline - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - if not hasattr(self, "loader"): - logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") - self.loader = None - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for expected_input_param in self.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediates_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.intermediates_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - @torch.compiler.disable - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - -class PipelineBlock(ModularPipelineMixin): +class PipelineBlock(ModularPipelineBlocks): model_name = None @@ -680,7 +590,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -class AutoPipelineBlocks(ModularPipelineMixin): +class AutoPipelineBlocks(ModularPipelineBlocks): """ A class that automatically selects a block to run based on the inputs. @@ -969,7 +879,8 @@ class AutoPipelineBlocks(ModularPipelineMixin): expected_configs=self.expected_configs ) -class SequentialPipelineBlocks(ModularPipelineMixin): + +class SequentialPipelineBlocks(ModularPipelineBlocks): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ @@ -1009,15 +920,24 @@ class SequentialPipelineBlocks(ModularPipelineMixin): """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. Args: - blocks_dict: Dictionary mapping block names to block instances + blocks_dict: Dictionary mapping block names to block classes or instances Returns: A new SequentialPipelineBlocks instance """ instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict + + # Create instances if classes are provided + blocks = {} + for name, block in blocks_dict.items(): + if inspect.isclass(block): + blocks[name] = block() + else: + blocks[name] = block + + instance.block_classes = [block.__class__ for block in blocks.values()] + instance.block_names = list(blocks.keys()) + instance.blocks = blocks return instance def __init__(self): @@ -1330,7 +1250,7 @@ class SequentialPipelineBlocks(ModularPipelineMixin): ) #YiYi TODO: __repr__ -class LoopSequentialPipelineBlocks(ModularPipelineMixin): +class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. """ @@ -1694,7 +1614,24 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin): return result + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs # YiYi TODO: @@ -1889,19 +1826,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin): return torch.device(module._hf_hook.execution_device) return self.device - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") @property def dtype(self) -> torch.dtype: @@ -2197,4 +2121,105 @@ class ModularLoader(ConfigMixin, PushToHubMixin): name=name, type_hint=type_hint, **spec_dict, - ) \ No newline at end of file + ) + + +class ModularPipeline: + """ + Base class for all Modular pipelines. + + Args: + blocks: ModularPipelineBlocks, the blocks to be used in the pipeline + loader: ModularLoader, the loader to be used in the pipeline + """ + + def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): + self.blocks = blocks + self.loader = loader + + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediates_inputs: + state.add_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.add_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.add_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self.blocks(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + + def load_components(self, component_names: Optional[List[str]] = None, **kwargs): + self.loader.load(component_names=component_names, **kwargs) + + def update_components(self, **kwargs): + self.loader.update(**kwargs) + + def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs) + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs) + return ModularPipeline(blocks=blocks, loader=loader) + + def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): + self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + + @property + def doc(self): + return self.blocks.doc \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 9ee9c06927..5f5e1c6c78 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -1,5 +1,5 @@ from ..configuration_utils import ConfigMixin -from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin +from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks from .modular_pipeline_utils import InputParam, OutputParam from ..image_processor import PipelineImageInput from pathlib import Path @@ -202,7 +202,7 @@ class ModularNode(ConfigMixin): trust_remote_code: Optional[bool] = None, **kwargs, ): - blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) + blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) return cls(blocks, **kwargs) def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):