From d0fbf745e6e27185a8c465ced3373e2f77cf37e2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 18 May 2025 18:52:12 +0200 Subject: [PATCH] refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method --- .../modular_pipeline_utils.py | 72 ++++++++----------- 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index a82f83fc38..0c6d1b5855 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -71,34 +71,31 @@ class ComponentSpec: self.default_creation_method == other.default_creation_method) @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component created by `create` or `load` method.""" if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") + raise ValueError("Component is not created by `create` or `load` method") + # throw a error if component is created with `create` method but not a subclass of ConfigMixin + # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + "We currently only support creating ComponentSpec from a component with " + "created with `ComponentSpec.load` method" + "or created with `ComponentSpec.create` and a subclass of ConfigMixin" + ) type_hint = component.__class__ + default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + if isinstance(component, ConfigMixin): config = component.config else: config = None load_spec = cls.decode_load_id(component._diffusers_load_id) - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) @classmethod def loading_fields(cls) -> List[str]: @@ -137,7 +134,7 @@ class ComponentSpec: "revision": "revision" } If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). + Returns None if load_id is "null" (indicating component not created with `load` method). """ # Get all loading fields in order @@ -158,20 +155,12 @@ class ComponentSpec: return result - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): @@ -201,34 +190,35 @@ class ComponentSpec: return component # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" + def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") if self.type_hint is None: try: from diffusers import AutoModel component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully self.type_hint = component.__class__ else: try: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} using load method: {e}") - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) component._diffusers_load_id = self.load_id return component