mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
style
This commit is contained in:
@@ -12,19 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
import os
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import PushToHubMixin, get_logger
|
||||
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
@@ -221,8 +219,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the guider
|
||||
configuration saved with [`~BaseGuidance.save_pretrained`].
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
||||
saved with [`~BaseGuidance.save_pretrained`].
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
@@ -285,6 +283,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
|
||||
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import create_repo
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
@@ -34,6 +35,7 @@ from ..utils import (
|
||||
logging,
|
||||
)
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from .components_manager import ComponentsManager
|
||||
from .modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
@@ -47,8 +49,7 @@ from .modular_pipeline_utils import (
|
||||
format_intermediates_short,
|
||||
make_doc_string,
|
||||
)
|
||||
from huggingface_hub import create_repo
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
@@ -1670,16 +1671,21 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
This method sets up the pipeline by:
|
||||
1. creating default pipeline blocks if not provided
|
||||
2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs)
|
||||
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided
|
||||
2. gather component and config specifications based on the pipeline blocks's requirement (e.g.
|
||||
expected_components, expected_configs)
|
||||
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from
|
||||
huggingface hub if `pretrained_model_name_or_path` is provided
|
||||
4. create defaultfrom_config components and register everything
|
||||
|
||||
Args:
|
||||
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
|
||||
default blocks based on the pipeline class name.
|
||||
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
|
||||
will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file.
|
||||
components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies.
|
||||
will load component specs (only for from_pretrained components) and config values from the saved
|
||||
modular_model_index.json file.
|
||||
components_manager:
|
||||
Optional ComponentsManager for managing multiple component cross different pipelines and apply
|
||||
offloading strategies.
|
||||
collection: Optional collection name for organizing components in the ComponentsManager.
|
||||
**kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
|
||||
|
||||
@@ -1693,18 +1699,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Initialize with components manager
|
||||
pipeline = ModularPipeline(
|
||||
blocks=my_blocks,
|
||||
components_manager=ComponentsManager(),
|
||||
collection="my_collection"
|
||||
blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection"
|
||||
)
|
||||
```
|
||||
|
||||
Notes:
|
||||
- If blocks is None, the method will try to find default blocks based on the pipeline class name
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()`
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included
|
||||
in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
|
||||
`load_default_components()`/`load_components()`
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
|
||||
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
`_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
|
||||
@@ -1769,12 +1777,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement.
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
|
||||
created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
|
||||
"latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
@@ -1794,11 +1804,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
|
||||
`output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
@@ -1880,11 +1891,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
|
||||
Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file.
|
||||
Path to a pretrained pipeline configuration. If provided, will load component specs (only for
|
||||
from_pretrained components) and config values from the modular_model_index.json file.
|
||||
trust_remote_code (`bool`, optional):
|
||||
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
||||
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
|
||||
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
||||
components_manager (`ComponentsManager`, optional):
|
||||
ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies.
|
||||
ComponentsManager instance for managing multiple component cross different pipelines and apply
|
||||
offloading strategies.
|
||||
collection (`str`, optional):`
|
||||
Collection name for organizing components in the ComponentsManager.
|
||||
"""
|
||||
@@ -1935,8 +1949,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
push_to_hub (`bool`, optional):
|
||||
Whether to push the pipeline to the huggingface hub.
|
||||
**kwargs: Additional arguments passed to `save_config()` method
|
||||
|
||||
|
||||
"""
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
@@ -1945,12 +1957,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
|
||||
model_card = populate_model_card(model_card)
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
|
||||
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
@@ -1977,7 +1989,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
This method is responsible for:
|
||||
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
|
||||
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components)
|
||||
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only
|
||||
for from_pretrained components)
|
||||
3. Adds components to the component manager if one is attached (only for from_pretrained components)
|
||||
|
||||
This method is called when:
|
||||
@@ -1986,15 +1999,18 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
|
||||
loader.update_components(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"])
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
|
||||
loader.load_default_components(names=["unet"])
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
|
||||
|
||||
Notes:
|
||||
- When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method
|
||||
- When registering None for a component, it sets attribute to None but still syncs specs with the config
|
||||
dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- component_specs are updated to match the new component outside of this method, e.g. in
|
||||
`update_components()` method
|
||||
"""
|
||||
for name, module in kwargs.items():
|
||||
# current component spec
|
||||
@@ -2166,7 +2182,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
def components(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns:
|
||||
- Dictionary mapping component names to their objects (include both from_pretrained and from_config components)
|
||||
- Dictionary mapping component names to their objects (include both from_pretrained and from_config
|
||||
components)
|
||||
"""
|
||||
# return only components we've actually set as attributes on self
|
||||
return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
|
||||
@@ -2186,19 +2203,21 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
|
||||
2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
|
||||
|
||||
In addition to updating the components and configuration values as pipeline attributes, the method also updates:
|
||||
In addition to updating the components and configuration values as pipeline attributes, the method also
|
||||
updates:
|
||||
- the corresponding specs in `_component_specs` and `_config_specs`
|
||||
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
|
||||
Args:
|
||||
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
|
||||
- Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method
|
||||
i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules
|
||||
(e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component
|
||||
(e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
|
||||
- Configuration values: Simple values to update configuration settings
|
||||
(e.g., `requires_safety_checker=False`)
|
||||
- Component objects: Only supports components we can extract specs using
|
||||
`ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
|
||||
ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
|
||||
method to create a new component (e.g., `guider=ComponentSpec(name="guider",
|
||||
type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
|
||||
- Configuration values: Simple values to update configuration settings (e.g.,
|
||||
`requires_safety_checker=False`)
|
||||
|
||||
Raises:
|
||||
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
|
||||
@@ -2228,9 +2247,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
|
||||
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been
|
||||
shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
|
||||
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components()
|
||||
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
|
||||
update_components()
|
||||
"""
|
||||
|
||||
# extract component_specs_updates & config_specs_updates from `specs`
|
||||
@@ -2244,7 +2265,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
for name, component in passed_components.items():
|
||||
current_component_spec = self._component_specs[name]
|
||||
|
||||
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(
|
||||
component, current_component_spec.type_hint
|
||||
@@ -2255,10 +2276,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
# update _component_specs based on the new component
|
||||
new_component_spec = ComponentSpec.from_component(name, component)
|
||||
if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
|
||||
logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.")
|
||||
|
||||
self._component_specs[name] = new_component_spec
|
||||
logger.warning(
|
||||
f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}."
|
||||
)
|
||||
|
||||
self._component_specs[name] = new_component_spec
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
@@ -2551,8 +2573,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
@staticmethod
|
||||
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
|
||||
"""
|
||||
Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`.
|
||||
If the `default_creation_method` is not `from_pretrained`, return None.
|
||||
Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If
|
||||
the `default_creation_method` is not `from_pretrained`, return None.
|
||||
|
||||
This dict contains:
|
||||
- "type_hint": Tuple[str, str]
|
||||
|
||||
@@ -18,15 +18,18 @@ from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..utils import is_torch_available, logging
|
||||
import torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class InsertableDict(OrderedDict):
|
||||
def insert(self, key, value, index):
|
||||
items = list(self.items())
|
||||
@@ -112,18 +115,18 @@ class ComponentSpec:
|
||||
@classmethod
|
||||
def from_component(cls, name: str, component: Any) -> Any:
|
||||
"""Create a ComponentSpec from a Component.
|
||||
|
||||
|
||||
Currently supports:
|
||||
- Components created with `ComponentSpec.load()` method
|
||||
- Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
|
||||
|
||||
|
||||
Args:
|
||||
name: Name of the component
|
||||
component: Component object to create spec from
|
||||
|
||||
|
||||
Returns:
|
||||
ComponentSpec object
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
|
||||
"""
|
||||
@@ -142,7 +145,9 @@ class ComponentSpec:
|
||||
elif isinstance(component, ConfigMixin):
|
||||
# warn if component was not created with `ComponentSpec`
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
logger.warning("Component was not created using `ComponentSpec`, defaulting to `from_config` creation method")
|
||||
logger.warning(
|
||||
"Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
|
||||
)
|
||||
default_creation_method = "from_config"
|
||||
else:
|
||||
# Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
|
||||
@@ -152,7 +157,6 @@ class ComponentSpec:
|
||||
f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
|
||||
)
|
||||
|
||||
|
||||
type_hint = component.__class__
|
||||
|
||||
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
|
||||
|
||||
@@ -482,7 +482,12 @@ class PushToHubMixin:
|
||||
|
||||
logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
|
||||
return upload_folder(
|
||||
repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr, path_in_repo=subfolder
|
||||
repo_id=repo_id,
|
||||
folder_path=working_dir,
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
path_in_repo=subfolder,
|
||||
)
|
||||
|
||||
def push_to_hub(
|
||||
|
||||
Reference in New Issue
Block a user