1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

move save_pretrained to the correct place

This commit is contained in:
yiyixuxu
2025-06-25 08:55:06 +02:00
parent e49413d87d
commit ffbaa890ba

View File

@@ -248,7 +248,7 @@ class BlockState:
class ModularPipelineBlocks(ConfigMixin):
"""
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks
"""
config_name = "config.json"
@@ -307,6 +307,20 @@ class ModularPipelineBlocks(ConfigMixin):
}
return block_cls(**block_kwargs)
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
# TODO: factor out this logic.
cls_name = self.__class__.__name__
full_mod = type(self).__module__
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
"""
@@ -532,21 +546,6 @@ class PipelineBlock(ModularPipelineBlocks):
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(param_name, param, input_param.kwargs_type)
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
# TODO: factor out this logic.
cls_name = self.__class__.__name__
full_mod = type(self).__module__
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
_component_names = [c.name for c in self.expected_components]
self.register_to_config(auto_map=auto_map, _component_names=_component_names)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
@@ -2366,9 +2365,7 @@ class ModularPipeline:
self.loader = loader
def __repr__(self):
blocks_class = self.blocks.__class__.__name__
loader_class = self.loader.__class__.__name__
return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})"
return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)"
@property
def default_call_parameters(self) -> Dict[str, Any]: