diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index 43d9ea8857..07fca44678 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -27,7 +27,7 @@ from ..utils import logging from . import BaseDiffusersCLICommand -EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"] +EXPECTED_PARENT_CLASSES = ["PipelineBlock"] CONFIG = "config.json" diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index b99478cb58..6607ea1efe 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -351,6 +351,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): "token", ] hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + hub_kwargs.update({"trust_remote_code": trust_remote_code}) config = cls.load_config(pretrained_model_name_or_path) has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] @@ -358,7 +359,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): trust_remote_code, pretrained_model_name_or_path, has_remote_code ) if not (has_remote_code and trust_remote_code): - raise ValueError("TODO") + raise ValueError( + "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file." + ) class_ref = config["auto_map"][cls.__name__] module_file, class_name = class_ref.split(".") @@ -367,7 +370,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): pretrained_model_name_or_path, module_file=module_file, class_name=class_name, - is_modular=True, **hub_kwargs, **kwargs, ) @@ -384,10 +386,39 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): full_mod = type(self).__module__ module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") - parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + parent_module = self.__class__.__bases__[0].__name__ auto_map = {f"{parent_module}": f"{module}.{cls_name}"} self.register_to_config(auto_map=auto_map) + + _component_specs = {spec.name: deepcopy(spec) for spec in self.expected_components} + _config_specs = {spec.name: deepcopy(spec) for spec in self.expected_configs} + + register_components_dict = {} + for name, component_spec in _component_specs.items(): + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + + # Since ModularPipelineBlocks can never have loaded components we set + # first two fields in the config dict to None + register_components_dict[name] = ( + None, + None, + { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + }, + ) + self.register_to_config(**register_components_dict) + + default_configs = {} + for name, config_spec in _config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) config = dict(self.config) self._internal_dict = FrozenDict(config) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 4fac5ef4f2..b63925df26 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -93,7 +93,7 @@ class ComponentSpec: config: Optional[FrozenDict] = None # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) - subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default="", metadata={"loading": True}) variant: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True}) default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"