From 163341d3dd6c7ca8d375630a3b41363d1da3c9ce Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:58:26 +0200 Subject: [PATCH] refactor modular loader: 1. load() only loads from_pretrained components if no specific names are given 2. update() accepts a ComponentSpec and creates the component from it 3. move the update _component_specs logic outside register_components to each method that creates/updates the component: __init__/update/load --- .../modular_pipelines/modular_pipeline.py | 124 ++++++++++++------ 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5dcb903db4..1c67a38717 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1651,54 +1651,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - + Register components with their corresponding specifications. + + This method is responsible for: + 1. Setting component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updating the modular_model_index.json configuration for serialization + 3. Adding components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components are not loaded during __init__, so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + Args: **kwargs: Keyword arguments where keys are component names and values are component objects. 
+ E.g., register_components(unet=unet_model, text_encoder=encoder_model) + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) + # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - # actual library and class name of the module - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) else: + # if module is None, e.g. 
self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based oncomponent spec library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below if not is_registered: self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec setattr(self, name, module) - if module is not None and self._component_manager is not None: + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue @@ -1707,10 +1721,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin): if current_module is module: logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: 
{component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: @@ -1718,7 +1728,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) - # same type, new instance → debug + # same type, new instance → replace but send debug log elif current_module is not None \ and module is not None \ and isinstance(module, current_module.__class__) \ @@ -1728,13 +1738,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin): f"(same type {type(current_module).__name__}, new instance)" ) - # save modular_model_index.json config + # update modular_model_index.json config self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) - if module is not None and self._component_manager is not None: + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) @@ -1758,6 +1767,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): + # only update component_spec for from_pretrained components if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) @@ -1768,7 +1778,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin): register_components_dict = {} for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + 
component = None + register_components_dict[name] = component self.register_components(**register_components_dict) default_configs = {} @@ -1870,6 +1884,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): **kwargs: Component objects or configuration values to update: - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it Raises: ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) @@ -1893,22 +1908,52 @@ class ModularLoader(ConfigMixin, PushToHubMixin): unet=new_unet_model, requires_safety_checker=False ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) ``` """ # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method + if 
component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in update() method " + f"because it is not supported in `ComponentSpec.from_component()`. " + f"Please pass a ComponentSpec object instead." + ) + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - self.register_components(**passed_components) + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) config_to_register = {} @@ -1932,8 +1977,9 @@ class ModularLoader(ConfigMixin, PushToHubMixin): - a dict, 
e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ + # if not specific name, load all the components with default_creation_method == "from_pretrained" if component_names is None: - component_names = list(self._component_specs.keys()) + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] elif not isinstance(component_names, list): component_names = [component_names] @@ -1958,7 +2004,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): # check if the default is specified component_load_kwargs[key] = value["default"] try: - components_to_register[name] = spec.create(**component_load_kwargs) + components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") @@ -1986,7 +2032,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) @@ -2010,7 +2056,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): else: # append a empty component spec for these not in modular_model_index component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) @staticmethod