mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
updatee modular_pipeline.from_pretrained, modular_repo ->pretrained_model_name_or_path
This commit is contained in:
@@ -264,6 +264,8 @@ else:
|
||||
_import_structure["modular_pipelines"].extend(
|
||||
[
|
||||
"ModularLoader",
|
||||
"ModularPipeline",
|
||||
"ModularPipelineBlocks",
|
||||
"ComponentSpec",
|
||||
"ComponentsManager",
|
||||
]
|
||||
@@ -894,6 +896,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .modular_pipelines import (
|
||||
ModularLoader,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
ComponentsManager,
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
@@ -54,6 +55,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularLoader,
|
||||
ModularPipelineBlocks,
|
||||
ModularPipeline,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
|
||||
@@ -305,7 +305,7 @@ class ModularPipelineBlocks(ConfigMixin):
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
|
||||
def init_pipeline(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
create a ModularLoader, optionally accept modular_repo to load from hub.
|
||||
"""
|
||||
@@ -319,7 +319,7 @@ class ModularPipelineBlocks(ConfigMixin):
|
||||
# Create the loader with the updated specs
|
||||
specs = component_specs + config_specs
|
||||
|
||||
loader = loader_class(specs=specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
|
||||
loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection)
|
||||
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
|
||||
return modular_pipeline
|
||||
|
||||
@@ -1748,7 +1748,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
|
||||
|
||||
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
||||
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Initialize the loader with a list of component specs and config specs.
|
||||
"""
|
||||
@@ -1762,8 +1762,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
}
|
||||
|
||||
# update component_specs and config_specs from modular_repo
|
||||
if modular_repo is not None:
|
||||
config_dict = self.load_config(modular_repo, **kwargs)
|
||||
if pretrained_model_name_or_path is not None:
|
||||
config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for name, value in config_dict.items():
|
||||
# only update component_spec for from_pretrained components
|
||||
@@ -2215,10 +2215,12 @@ class ModularPipeline:
|
||||
def update_components(self, **kwargs):
|
||||
self.loader.update(**kwargs)
|
||||
|
||||
def from_pretrained(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
loader = ModularLoader.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
return ModularPipeline(blocks=blocks, loader=loader)
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
|
||||
pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs)
|
||||
return pipeline
|
||||
|
||||
def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs):
|
||||
self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user