mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method
This commit is contained in:
@@ -71,34 +71,31 @@ class ComponentSpec:
|
||||
self.default_creation_method == other.default_creation_method)
|
||||
|
||||
@classmethod
|
||||
def from_component(cls, name: str, component: torch.nn.Module) -> Any:
|
||||
"""Create a ComponentSpec from a Component created by `create` method."""
|
||||
def from_component(cls, name: str, component: Any) -> Any:
|
||||
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
|
||||
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
raise ValueError("Component is not created by `create` method")
|
||||
raise ValueError("Component is not created by `create` or `load` method")
|
||||
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
|
||||
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
|
||||
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
|
||||
raise ValueError(
|
||||
"We currently only support creating ComponentSpec from a component with "
|
||||
"created with `ComponentSpec.load` method"
|
||||
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
|
||||
)
|
||||
|
||||
type_hint = component.__class__
|
||||
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
|
||||
|
||||
if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
|
||||
if isinstance(component, ConfigMixin):
|
||||
config = component.config
|
||||
else:
|
||||
config = None
|
||||
|
||||
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
||||
|
||||
return cls(name=name, type_hint=type_hint, config=config, **load_spec)
|
||||
|
||||
@classmethod
|
||||
def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
|
||||
"""Create a ComponentSpec from a load_id string."""
|
||||
if load_id == "null":
|
||||
raise ValueError("Cannot create ComponentSpec from null load_id")
|
||||
|
||||
# Decode the load_id into a dictionary of loading fields
|
||||
load_fields = cls.decode_load_id(load_id)
|
||||
|
||||
# Create a new ComponentSpec instance with the decoded fields
|
||||
return cls(name=name, **load_fields)
|
||||
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
|
||||
|
||||
@classmethod
|
||||
def loading_fields(cls) -> List[str]:
|
||||
@@ -137,7 +134,7 @@ class ComponentSpec:
|
||||
"revision": "revision"
|
||||
}
|
||||
If a segment value is "null", it's replaced with None.
|
||||
Returns None if load_id is "null" (indicating component not loaded from pretrained).
|
||||
Returns None if load_id is "null" (indicating component not created with `load` method).
|
||||
"""
|
||||
|
||||
# Get all loading fields in order
|
||||
@@ -158,20 +155,12 @@ class ComponentSpec:
|
||||
|
||||
return result
|
||||
|
||||
# YiYi TODO: add validator
|
||||
def create(self, **kwargs) -> Any:
|
||||
"""Create the component using the preferred creation method."""
|
||||
|
||||
# from_pretrained creation
|
||||
if self.default_creation_method == "from_pretrained":
|
||||
return self.create_from_pretrained(**kwargs)
|
||||
elif self.default_creation_method == "from_config":
|
||||
# from_config creation
|
||||
return self.create_from_config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid creation method: {self.default_creation_method}")
|
||||
|
||||
def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
||||
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
|
||||
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
|
||||
# the config info is lost in the process
|
||||
# remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method
|
||||
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
|
||||
"""Create component using from_config with config."""
|
||||
|
||||
if self.type_hint is None or not isinstance(self.type_hint, type):
|
||||
@@ -201,34 +190,35 @@ class ComponentSpec:
|
||||
return component
|
||||
|
||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||
def create_from_pretrained(self, **kwargs) -> Any:
|
||||
"""Create component using from_pretrained."""
|
||||
def load(self, **kwargs) -> Any:
|
||||
"""Load component using from_pretrained."""
|
||||
|
||||
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
|
||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||
# merge loading field value in the spec with user passed values to create load_kwargs
|
||||
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||
repo = load_kwargs.pop("repo", None)
|
||||
if repo is None:
|
||||
raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||
|
||||
if self.type_hint is None:
|
||||
try:
|
||||
from diffusers import AutoModel
|
||||
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}")
|
||||
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
||||
# update type_hint if AutoModel load successfully
|
||||
self.type_hint = component.__class__
|
||||
else:
|
||||
try:
|
||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")
|
||||
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
||||
|
||||
if repo != self.repo:
|
||||
self.repo = repo
|
||||
for k, v in passed_loading_kwargs.items():
|
||||
if v is not None:
|
||||
setattr(self, k, v)
|
||||
self.repo = repo
|
||||
for k, v in load_kwargs.items():
|
||||
setattr(self, k, v)
|
||||
component._diffusers_load_id = self.load_id
|
||||
|
||||
return component
|
||||
|
||||
Reference in New Issue
Block a user