mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
move save_pretrained to the correct place
This commit is contained in:
@@ -248,7 +248,7 @@ class BlockState:
|
||||
|
||||
class ModularPipelineBlocks(ConfigMixin):
|
||||
"""
|
||||
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
@@ -307,6 +307,20 @@ class ModularPipelineBlocks(ConfigMixin):
|
||||
}
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
|
||||
# TODO: factor out this logic.
|
||||
cls_name = self.__class__.__name__
|
||||
|
||||
full_mod = type(self).__module__
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
|
||||
self.register_to_config(auto_map=auto_map)
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
|
||||
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
@@ -532,21 +546,6 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(param_name, param, input_param.kwargs_type)
|
||||
|
||||
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
|
||||
# TODO: factor out this logic.
|
||||
cls_name = self.__class__.__name__
|
||||
|
||||
full_mod = type(self).__module__
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
_component_names = [c.name for c in self.expected_components]
|
||||
|
||||
self.register_to_config(auto_map=auto_map, _component_names=_component_names)
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
"""
|
||||
@@ -2366,9 +2365,7 @@ class ModularPipeline:
|
||||
self.loader = loader
|
||||
|
||||
def __repr__(self):
|
||||
blocks_class = self.blocks.__class__.__name__
|
||||
loader_class = self.loader.__class__.__name__
|
||||
return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})"
|
||||
return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)"
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user