diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c26a9c7c8a..cdb28519f4 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -248,7 +248,7 @@ class BlockState: class ModularPipelineBlocks(ConfigMixin): """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks """ config_name = "config.json" @@ -307,6 +307,20 @@ class ModularPipelineBlocks(ConfigMixin): } return block_cls(**block_kwargs) + + def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): + # TODO: factor out this logic. + cls_name = self.__class__.__name__ + + full_mod = type(self).__module__ + module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + auto_map = {f"{parent_module}": f"{module}.{cls_name}"} + + self.register_to_config(auto_map=auto_map) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + self._internal_dict = FrozenDict(config) def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ @@ -532,21 +546,6 @@ class PipelineBlock(ModularPipelineBlocks): if current_value is not param: # Using identity comparison to check if object was modified state.add_intermediate(param_name, param, input_param.kwargs_type) - def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): - # TODO: factor out this logic. - cls_name = self.__class__.__name__ - - full_mod = type(self).__module__ - module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") - parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] - auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - _component_names = [c.name for c in self.expected_components] - - self.register_to_config(auto_map=auto_map, _component_names=_component_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - self._internal_dict = FrozenDict(config) - def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ @@ -2366,9 +2365,7 @@ class ModularPipeline: self.loader = loader def __repr__(self): - blocks_class = self.blocks.__class__.__name__ - loader_class = self.loader.__class__.__name__ - return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})" + return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)" @property def default_call_parameters(self) -> Dict[str, Any]: