mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -27,7 +27,7 @@ from ..utils import logging
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
|
||||
EXPECTED_PARENT_CLASSES = ["PipelineBlock"]
|
||||
CONFIG = "config.json"
|
||||
|
||||
|
||||
|
||||
@@ -351,6 +351,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
hub_kwargs.update({"trust_remote_code": trust_remote_code})
|
||||
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
@@ -358,7 +359,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError("TODO")
|
||||
raise ValueError(
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
@@ -367,7 +370,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
is_modular=True,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -384,10 +386,39 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
full_mod = type(self).__module__
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
parent_module = self.__class__.__bases__[0].__name__
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
|
||||
self.register_to_config(auto_map=auto_map)
|
||||
|
||||
_component_specs = {spec.name: deepcopy(spec) for spec in self.expected_components}
|
||||
_config_specs = {spec.name: deepcopy(spec) for spec in self.expected_configs}
|
||||
|
||||
register_components_dict = {}
|
||||
for name, component_spec in _component_specs.items():
|
||||
if component_spec.type_hint is not None:
|
||||
lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint)
|
||||
else:
|
||||
lib_name = cls_name = None
|
||||
load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()}
|
||||
|
||||
# Since ModularPipelineBlocks can never have loaded components we set
|
||||
# first two fields in the config dict to None
|
||||
register_components_dict[name] = (
|
||||
None,
|
||||
None,
|
||||
{
|
||||
"type_hint": (lib_name, cls_name),
|
||||
**load_spec_dict,
|
||||
},
|
||||
)
|
||||
self.register_to_config(**register_components_dict)
|
||||
|
||||
default_configs = {}
|
||||
for name, config_spec in _config_specs.items():
|
||||
default_configs[name] = config_spec.default
|
||||
self.register_to_config(**default_configs)
|
||||
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
|
||||
@@ -93,7 +93,7 @@ class ComponentSpec:
|
||||
config: Optional[FrozenDict] = None
|
||||
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
|
||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||
|
||||
Reference in New Issue
Block a user