The Modular Diffusers (#9672)
adding modular diffusers as experimental feature

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
src/diffusers/__init__.py
@@ -34,9 +34,11 @@ from .utils import (
_import_structure = {
    "configuration_utils": ["ConfigMixin"],
    "guiders": [],
    "hooks": [],
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
    "modular_pipelines": [],
    "pipelines": [],
    "quantizers.quantization_config": [],
    "schedulers": [],
@@ -130,14 +132,29 @@ except OptionalDependencyNotAvailable:
    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
    _import_structure["guiders"].extend(
        [
            "AdaptiveProjectedGuidance",
            "AutoGuidance",
            "ClassifierFreeGuidance",
            "ClassifierFreeZeroStarGuidance",
            "PerturbedAttentionGuidance",
            "SkipLayerGuidance",
            "SmoothedEnergyGuidance",
            "TangentialClassifierFreeGuidance",
        ]
    )
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "FirstBlockCacheConfig",
            "HookRegistry",
            "LayerSkipConfig",
            "PyramidAttentionBroadcastConfig",
            "SmoothedEnergyGuidanceConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_layer_skip",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -221,6 +238,14 @@ else:
            "WanVACETransformer3DModel",
        ]
    )
    _import_structure["modular_pipelines"].extend(
        [
            "ComponentsManager",
            "ComponentSpec",
            "ModularPipeline",
            "ModularPipelineBlocks",
        ]
    )
    _import_structure["optimization"] = [
        "get_constant_schedule",
        "get_constant_schedule_with_warmup",
@@ -333,6 +358,12 @@ except OptionalDependencyNotAvailable:
    ]

else:
    _import_structure["modular_pipelines"].extend(
        [
            "StableDiffusionXLAutoBlocks",
            "StableDiffusionXLModularPipeline",
        ]
    )
    _import_structure["pipelines"].extend(
        [
            "AllegroPipeline",
@@ -545,6 +576,7 @@ else:
        ]
    )


try:
    if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
        raise OptionalDependencyNotAvailable()
@@ -751,13 +783,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
    else:
        from .guiders import (
            AdaptiveProjectedGuidance,
            AutoGuidance,
            ClassifierFreeGuidance,
            ClassifierFreeZeroStarGuidance,
            PerturbedAttentionGuidance,
            SkipLayerGuidance,
            SmoothedEnergyGuidance,
            TangentialClassifierFreeGuidance,
        )
        from .hooks import (
            FasterCacheConfig,
            FirstBlockCacheConfig,
            HookRegistry,
            LayerSkipConfig,
            PyramidAttentionBroadcastConfig,
            SmoothedEnergyGuidanceConfig,
            apply_faster_cache,
            apply_first_block_cache,
            apply_layer_skip,
            apply_pyramid_attention_broadcast,
        )
        from .models import (
@@ -837,6 +882,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            WanTransformer3DModel,
            WanVACETransformer3DModel,
        )
        from .modular_pipelines import (
            ComponentsManager,
            ComponentSpec,
            ModularPipeline,
            ModularPipelineBlocks,
        )
        from .optimization import (
            get_constant_schedule,
            get_constant_schedule_with_warmup,
@@ -933,6 +984,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_pipelines import (
            StableDiffusionXLAutoBlocks,
            StableDiffusionXLModularPipeline,
        )
        from .pipelines import (
            AllegroPipeline,
            AltDiffusionImg2ImgPipeline,
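With these lazy-import entries in place, the new experimental classes become importable from the top-level package like any other diffusers object. A quick sanity-check sketch (names taken from the hunks above; `guidance_scale` is a constructor argument shown in the guiders module later in this diff):

from diffusers import ClassifierFreeGuidance, ModularPipeline

guider = ClassifierFreeGuidance(guidance_scale=5.0)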
src/diffusers/commands/custom_blocks.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage example:
    TODO
"""

import ast
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path

from ..utils import logging
from . import BaseDiffusersCLICommand


EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
CONFIG = "config.json"


def conversion_command_factory(args: Namespace):
    return CustomBlocksCommand(args.block_module_name, args.block_class_name)


class CustomBlocksCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        conversion_parser = parser.add_parser("custom_blocks")
        conversion_parser.add_argument(
            "--block_module_name",
            type=str,
            default="block.py",
            help="Module filename in which the custom block will be implemented.",
        )
        conversion_parser.add_argument(
            "--block_class_name",
            type=str,
            default=None,
            help="Name of the custom block. If None, we will try to infer it.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
        self.logger = logging.get_logger("diffusers-cli/custom_blocks")
        self.block_module_name = Path(block_module_name)
        self.block_class_name = block_class_name

    def run(self):
        # determine the block to be saved.
        out = self._get_class_names(self.block_module_name)
        classes_found = list({cls for cls, _ in out})

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(out, self.block_class_name)
            if child_class is None and parent_class is None:
                raise ValueError(
                    "`block_class_name` could not be retrieved. Available classes from "
                    f"{self.block_module_name}:\n{classes_found}"
                )
        else:
            self.logger.info(
                f"Found classes: {classes_found}; will be using {classes_found[0]}. "
                "If this needs to be changed, re-run the command specifying `block_class_name`."
            )
            child_class, parent_class = out[0][0], out[0][1]

        # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
        # the user is responsible for running it, so I guess that is safe?
        module_name = f"__dynamic__{self.block_module_name.stem}"
        spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        getattr(module, child_class)().save_pretrained(os.getcwd())

        # or, we could create it manually.
        # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
        # with open(CONFIG, "w") as f:
        #     json.dump(automap, f)
        with open("requirements.txt", "w") as f:
            f.write("")

    def _choose_block(self, candidates, chosen=None):
        for cls, base in candidates:
            if cls == chosen:
                return cls, base
        return None, None

    def _get_class_names(self, file_path):
        source = file_path.read_text(encoding="utf-8")
        try:
            tree = ast.parse(source, filename=file_path)
        except SyntaxError as e:
            raise ValueError(f"Could not parse {file_path!r}: {e}") from e

        results: list[tuple[str, str]] = []
        for node in tree.body:
            if not isinstance(node, ast.ClassDef):
                continue

            # extract all base names for this class
            base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]

            # for each allowed base that appears in the class's bases, emit a tuple
            for allowed in EXPECTED_PARENT_CLASSES:
                if allowed in base_names:
                    results.append((node.name, allowed))

        return results

    def _get_base_name(self, node: ast.expr):
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}
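The module docstring's usage example is still a TODO upstream. As an illustrative sketch of the intended flow (the `MyBlock` name and `block.py` contents below are hypothetical, not part of the diff):

# block.py -- a minimal custom block; "MyBlock" is a hypothetical example name.
from diffusers import ModularPipelineBlocks

class MyBlock(ModularPipelineBlocks):
    pass

Running `diffusers-cli custom_blocks --block_module_name block.py --block_class_name MyBlock` from the same directory would then import `MyBlock` dynamically, call its `save_pretrained` on the current working directory, and write an empty `requirements.txt` alongside it.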
src/diffusers/commands/diffusers_cli.py
@@ -15,6 +15,7 @@

from argparse import ArgumentParser

from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand

@@ -26,6 +27,7 @@ def main():
    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)
    CustomBlocksCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()
src/diffusers/configuration_utils.py
@@ -176,6 +176,7 @@ class ConfigMixin:
            token = kwargs.pop("token", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
            subfolder = kwargs.pop("subfolder", None)

            self._upload_folder(
                save_directory,
@@ -183,6 +184,7 @@ class ConfigMixin:
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
                subfolder=subfolder,
            )

    @classmethod
@@ -601,6 +603,10 @@ class ConfigMixin:
                value = value.tolist()
            elif isinstance(value, Path):
                value = value.as_posix()
            elif hasattr(value, "to_dict") and callable(value.to_dict):
                value = value.to_dict()
            elif isinstance(value, list):
                value = [to_json_saveable(v) for v in value]
            return value

        if "quantization_config" in config_dict:
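The two new `to_json_saveable` branches exist for the guiders introduced below: a guider's config can now hold objects exposing `to_dict` (such as `LayerSkipConfig`, which is what the `hasattr` check targets) or lists of them, and both now serialize to plain JSON. Conceptually (a sketch only, since `to_json_saveable` is a local helper inside `ConfigMixin.to_json_string`):

# Assuming LayerSkipConfig exposes to_dict(); the call pattern mirrors AutoGuidance below.
config = LayerSkipConfig(7, fqn="auto", dropout=0.75)
to_json_saveable(config)             # -> config.to_dict(), a plain dict
to_json_saveable([config, config])   # -> recurses, a list of plain dicts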
src/diffusers/guiders/__init__.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from ..utils import is_torch_available


if is_torch_available():
    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
    from .auto_guidance import AutoGuidance
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
    from .perturbed_attention_guidance import PerturbedAttentionGuidance
    from .skip_layer_guidance import SkipLayerGuidance
    from .smoothed_energy_guidance import SmoothedEnergyGuidance
    from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance

    GuiderType = Union[
        AdaptiveProjectedGuidance,
        AutoGuidance,
        ClassifierFreeGuidance,
        ClassifierFreeZeroStarGuidance,
        PerturbedAttentionGuidance,
        SkipLayerGuidance,
        SmoothedEnergyGuidance,
        TangentialClassifierFreeGuidance,
    ]
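`GuiderType` is simply a `Union` over the eight concrete guider classes, intended for type annotations rather than instantiation. A tiny sketch (the function is hypothetical, not part of the diff):

from diffusers.guiders import ClassifierFreeGuidance, GuiderType

def describe(guider: GuiderType) -> str:
    # Any of the eight guider classes above type-checks here.
    return guider.__class__.__name__

print(describe(ClassifierFreeGuidance(guidance_scale=5.0)))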
src/diffusers/guiders/adaptive_projected_guidance.py (new file, 188 lines)
@@ -0,0 +1,188 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class AdaptiveProjectedGuidance(BaseGuidance):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The norm threshold used to rescale the guidance update before projection (passed as `norm_threshold` to
            `normalized_guidance`).
        eta (`float`, defaults to `1.0`):
            The weight applied to the component of the guidance update parallel to the conditional prediction.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_apg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_apg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_apg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel

    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update

    return pred
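A minimal sketch of how the pieces above compose (tensor shapes are illustrative assumptions, not from the diff): `normalized_guidance` splits the momentum-smoothed, norm-clipped guidance direction into components parallel and orthogonal to the conditional prediction, then weights the parallel part by `eta`.

import torch
from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer, normalized_guidance

pred_cond = torch.randn(2, 4, 64, 64)    # stand-in conditional noise prediction
pred_uncond = torch.randn(2, 4, 64, 64)  # stand-in unconditional noise prediction

buffer = MomentumBuffer(momentum=-0.5)   # example value; momentum is disabled when None
pred = normalized_guidance(
    pred_cond,
    pred_uncond,
    guidance_scale=7.5,
    momentum_buffer=buffer,
    eta=1.0,
    norm_threshold=15.0,                 # the `adaptive_projected_guidance_rescale` value
)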
src/diffusers/guiders/auto_guidance.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class AutoGuidance(BaseGuidance):
    """
    AutoGuidance: https://huggingface.co/papers/2406.02507

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        auto_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply AutoGuidance's layer skipping to. Can be a single integer or a list of
            integers. If not provided, `auto_guidance_config` must be provided.
        auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skipped layers. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
        dropout (`float`, *optional*):
            The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers`
            or `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        auto_guidance_layers: Optional[Union[int, List[int]]] = None,
        auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        dropout: Optional[float] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.auto_guidance_layers = auto_guidance_layers
        self.auto_guidance_config = auto_guidance_config
        self.dropout = dropout
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if auto_guidance_layers is None and auto_guidance_config is None:
            raise ValueError(
                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
            )
        if auto_guidance_layers is not None and auto_guidance_config is not None:
            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
        if (dropout is None and auto_guidance_layers is not None) or (
            dropout is not None and auto_guidance_layers is None
        ):
            raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")

        if auto_guidance_layers is not None:
            if isinstance(auto_guidance_layers, int):
                auto_guidance_layers = [auto_guidance_layers]
            if not isinstance(auto_guidance_layers, list):
                raise ValueError(
                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
                )
            auto_guidance_config = [
                LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
            ]

        if isinstance(auto_guidance_config, dict):
            auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)

        if isinstance(auto_guidance_config, LayerSkipConfig):
            auto_guidance_config = [auto_guidance_config]

        if not isinstance(auto_guidance_config, list):
            raise ValueError(
                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
            )
        elif isinstance(next(iter(auto_guidance_config), None), dict):
            auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]

        self.auto_guidance_config = auto_guidance_config
        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_ag_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_ag_enabled():
            num_conditions += 1
        return num_conditions

    def _is_ag_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
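Construction-wise, the class accepts either raw layer indices plus a dropout probability, or explicit `LayerSkipConfig`s. A small sketch under those rules (the layer indices are arbitrary examples):

from diffusers.guiders import AutoGuidance

# Indices + dropout: internally expanded to LayerSkipConfig(layer, fqn="auto", dropout=...)
guider = AutoGuidance(guidance_scale=5.0, auto_guidance_layers=[7, 8, 9], dropout=0.75)

# Equivalently, pass configs directly (then `dropout` must be left unset):
# from diffusers import LayerSkipConfig
# guider = AutoGuidance(guidance_scale=5.0, auto_guidance_config=LayerSkipConfig(7, fqn="auto", dropout=0.75))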
src/diffusers/guiders/classifier_free_guidance.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class ClassifierFreeGuidance(BaseGuidance):
    """
    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598

    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
    inference. This allows the model to trade off between generation quality and sample diversity. The original
    paper proposes scaling and shifting the conditional distribution based on the difference between conditional and
    unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]

    Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
    further away from the unconditional distribution estimates, while the diffusers-native implementation can be
    thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in
    the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled():
            pred = pred_cond
        else:
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
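Numerically, the two formulations in the docstring differ only in which prediction the scaled difference is added to. A worked sketch (the scale and tensors are illustrative):

import torch

pred_cond = torch.tensor([1.0, 2.0])
pred_uncond = torch.tensor([0.5, 1.0])
scale = 7.5
shift = pred_cond - pred_uncond       # [0.5, 1.0]

native = pred_uncond + scale * shift    # x_uncond + s * (x_cond - x_uncond) -> [4.25, 8.5]
original = pred_cond + scale * shift    # x_cond  + s * (x_cond - x_uncond) -> [4.75, 9.5]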
src/diffusers/guiders/classifier_free_zero_star_guidance.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class ClassifierFreeZeroStarGuidance(BaseGuidance):
    """
    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886

    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
    process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving
    the quality of generated images.

    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        zero_init_steps (`int`, defaults to `1`):
            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        zero_init_steps: int = 1,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if self._step < self.zero_init_steps:
            pred = torch.zeros_like(pred_cond)
        elif not self._is_cfg_enabled():
            pred = pred_cond
        else:
            pred_cond_flat = pred_cond.flatten(1)
            pred_uncond_flat = pred_uncond.flatten(1)
            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
            pred_uncond = pred_uncond * alpha
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    cond_dtype = cond.dtype
    cond = cond.float()
    uncond = uncond.float()
    dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
    squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    scale = dot_product / squared_norm
    return scale.to(dtype=cond_dtype)
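The helper at the bottom computes the optimal per-sample rescaling from the paper: the projection coefficient of the conditional prediction onto the unconditional one. A quick sketch (shapes are assumptions):

import torch
from diffusers.guiders.classifier_free_zero_star_guidance import cfg_zero_star_scale

cond = torch.randn(2, 4 * 64 * 64)     # flattened conditional prediction
uncond = torch.randn(2, 4 * 64 * 64)   # flattened unconditional prediction

# st_star = (v_cond . v_uncond) / ||v_uncond||^2, one scalar per batch element
scale = cfg_zero_star_scale(cond, uncond)
print(scale.shape)                      # torch.Size([2, 1])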
src/diffusers/guiders/guider_utils.py (new file, 309 lines)
@@ -0,0 +1,309 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import PushToHubMixin, get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
GUIDER_CONFIG_NAME = "guider_config.json"
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
r"""Base class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
config_name = GUIDER_CONFIG_NAME
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
if not (start <= stop <= 1.0):
|
||||
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
def enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
||||
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
||||
the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
||||
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
||||
conditional data identifier and the second element must be the unconditional data identifier or None.
|
||||
|
||||
Example:
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = (
|
||||
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
||||
subclasses to implement specific model preparation logic.
|
||||
"""
|
||||
self._count_prepared += 1
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
|
||||
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
||||
modifications made during `prepare_models`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
if len(data) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
||||
)
|
||||
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
||||
return self.forward(**forward_inputs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_unconditional(self) -> bool:
|
||||
return not self.is_conditional
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
||||
length 2 is provided, the first element must be the conditional data identifier and the second element
|
||||
must be the unconditional data identifier or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError(
|
||||
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
||||
)
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
) -> Self:
|
||||
r"""
|
||||
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
||||
saved with [`~BaseGuidance.save_pretrained`].
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether kwargs that are not consumed by the Python class should be returned or not.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
||||
`huggingface-cli login`. You can also activate the special
|
||||
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
||||
firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a guider configuration object to a directory so that it can be reloaded using the
        [`~BaseGuidance.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
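
# --- Editor's sketch (not part of the commit): how `rescale_noise_cfg` behaves, with assumed shapes.
import torch

noise_pred_text = torch.randn(2, 4, 64, 64)                    # conditional noise prediction
noise_cfg = noise_pred_text + 7.5 * torch.randn(2, 4, 64, 64)  # over-scaled CFG combination
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
# With guidance_rescale=0.0 the input is returned unchanged; with 1.0 the per-sample std of the
# output matches that of `noise_pred_text`, which is what counteracts overexposure.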

src/diffusers/guiders/perturbed_attention_guidance.py (new file, 271 lines)
@@ -0,0 +1,271 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


logger = get_logger(__name__)  # pylint: disable=invalid-name


class PerturbedAttentionGuidance(BaseGuidance):
    """
    Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

    The intuition behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the
    idea of using a worse version of the trained model for better guiding itself in the denoising process. It
    perturbs the attention scores of the latent stream by replacing the score matrix with an identity matrix for
    selectively chosen layers.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    PAG shares its implementation with SkipLayerGuidance due to the overlap in configuration parameters and
    implementation details.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        perturbed_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for perturbed attention guidance.
        perturbed_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which perturbed attention guidance starts.
        perturbed_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which perturbed attention guidance stops.
        perturbed_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of
            integers. If not provided, `perturbed_guidance_config` must be provided.
        perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
    # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
    # complex to support joint latent conditioning in a model-agnostic manner without specializing the
    # implementation for each model architecture.

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        perturbed_guidance_scale: float = 2.8,
        perturbed_guidance_start: float = 0.01,
        perturbed_guidance_stop: float = 0.2,
        perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
        perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = perturbed_guidance_scale
        self.skip_layer_guidance_start = perturbed_guidance_start
        self.skip_layer_guidance_stop = perturbed_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if perturbed_guidance_config is None:
            if perturbed_guidance_layers is None:
                raise ValueError(
                    "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
                )
            if isinstance(perturbed_guidance_layers, int):
                # normalize a single layer index to a list, as `LayerSkipConfig.indices` expects `List[int]`
                perturbed_guidance_layers = [perturbed_guidance_layers]
            perturbed_guidance_config = LayerSkipConfig(
                indices=perturbed_guidance_layers,
                fqn="auto",
                skip_attention=False,
                skip_attention_scores=True,
                skip_ff=False,
            )
        else:
            if perturbed_guidance_layers is not None:
                raise ValueError(
                    "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
                )

        if isinstance(perturbed_guidance_config, dict):
            perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)

        if isinstance(perturbed_guidance_config, LayerSkipConfig):
            perturbed_guidance_config = [perturbed_guidance_config]

        if not isinstance(perturbed_guidance_config, list):
            raise ValueError(
                "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
            )
        elif isinstance(next(iter(perturbed_guidance_config), None), dict):
            perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]

        for config in perturbed_guidance_config:
            if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
                logger.warning(
                    "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
                    "Please check your configuration. Modifying the config to match the expected values."
                )
                config.skip_attention = False
                config.skip_attention_scores = True
                config.skip_ff = False

        self.skip_layer_config = perturbed_guidance_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
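
# --- Editor's sketch (not part of the commit): constructing a PAG guider. The scales and layer
# indices below are illustrative assumptions; suitable values depend on the guided transformer.
guider = PerturbedAttentionGuidance(
    guidance_scale=5.0,
    perturbed_guidance_scale=3.0,
    perturbed_guidance_layers=[10, 20],  # expands to a LayerSkipConfig with skip_attention_scores=True
)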

src/diffusers/guiders/skip_layer_guidance.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class SkipLayerGuidance(BaseGuidance):
    """
    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5

    Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664

    SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
    skipping the forward pass of specified transformer blocks during the denoising process on an additional
    conditional batch of data, apart from the conditional and unconditional batches already used in CFG
    ([`~guiders.classifier_free_guidance.ClassifierFreeGuidance`]), and then scaling and shifting the CFG
    predictions based on the difference between the conditional-without-skipping and conditional-with-skipping
    predictions.

    The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
    worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
    version of the model for the conditional prediction).

    STG is an improvement and follow-up work combining ideas from SLG, PAG, and similar techniques for improving
    generation quality in video diffusion models.

    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

    The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` default
    to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        skip_layer_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
            values, but it may also lead to overexposure and saturation.
        skip_layer_guidance_start (`float`, defaults to `0.01`):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
        skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
            Diffusion 3.5 Medium.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
            `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_start: float = 0.01,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
        skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.skip_layer_guidance_scale = skip_layer_guidance_scale
        self.skip_layer_guidance_start = skip_layer_guidance_start
        self.skip_layer_guidance_stop = skip_layer_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= skip_layer_guidance_start < 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
            )
        if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
            raise ValueError(
                f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
            )

        if skip_layer_guidance_layers is None and skip_layer_config is None:
            raise ValueError(
                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
            )
        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

        if skip_layer_guidance_layers is not None:
            if isinstance(skip_layer_guidance_layers, int):
                skip_layer_guidance_layers = [skip_layer_guidance_layers]
            if not isinstance(skip_layer_guidance_layers, list):
                raise ValueError(
                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                )
            # `LayerSkipConfig.indices` expects `List[int]`, so wrap each layer index in a list
            skip_layer_config = [LayerSkipConfig(indices=[layer], fqn="auto") for layer in skip_layer_guidance_layers]

        if isinstance(skip_layer_config, dict):
            skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)

        if isinstance(skip_layer_config, LayerSkipConfig):
            skip_layer_config = [skip_layer_config]

        if not isinstance(skip_layer_config, list):
            raise ValueError(
                f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
            )
        elif isinstance(next(iter(skip_layer_config), None), dict):
            skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]

        self.skip_layer_config = skip_layer_config
        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_skip: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        pred = None

        if not self._is_cfg_enabled() and not self._is_slg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_cond_skip
            pred = pred + self.skip_layer_guidance_scale * shift
        elif not self._is_slg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_skip = pred_cond - pred_cond_skip
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_slg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
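
# --- Editor's sketch (not part of the commit): constructing an SLG guider with the layer indices
# the docstring recommends for Stable Diffusion 3.5 Medium. `guidance_scale` is an assumption.
guider = SkipLayerGuidance(
    guidance_scale=4.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_layers=[7, 8, 9],
)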

src/diffusers/guiders/smoothed_energy_guidance.py (new file, 251 lines)
@@ -0,0 +1,251 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class SmoothedEnergyGuidance(BaseGuidance):
    """
    Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility. This implementation assumes:
    - Generated images are square (height == width)
    - The model does not combine different modalities together (e.g., text and image latent streams are not combined
      together, as they are in Flux)

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        seg_guidance_scale (`float`, defaults to `2.8`):
            The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with
            higher values, but it may also lead to overexposure and saturation.
        seg_blur_sigma (`float`, defaults to `9999999.0`):
            The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
            infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
        seg_blur_threshold_inf (`float`, defaults to `9999.0`):
            The threshold above which the blur is considered infinite.
        seg_guidance_start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance starts.
        seg_guidance_stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which smoothed energy guidance stops.
        seg_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers.
            If not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for
            Stable Diffusion 3.5 Medium.
        seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
            The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig`
            or a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        seg_guidance_scale: float = 2.8,
        seg_blur_sigma: float = 9999999.0,
        seg_blur_threshold_inf: float = 9999.0,
        seg_guidance_start: float = 0.0,
        seg_guidance_stop: float = 1.0,
        seg_guidance_layers: Optional[Union[int, List[int]]] = None,
        seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.seg_guidance_scale = seg_guidance_scale
        self.seg_blur_sigma = seg_blur_sigma
        self.seg_blur_threshold_inf = seg_blur_threshold_inf
        self.seg_guidance_start = seg_guidance_start
        self.seg_guidance_stop = seg_guidance_stop
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

        if not (0.0 <= seg_guidance_start < 1.0):
            raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
        if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
            raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")

        if seg_guidance_layers is None and seg_guidance_config is None:
            raise ValueError(
                "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
            )
        if seg_guidance_layers is not None and seg_guidance_config is not None:
            raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")

        if seg_guidance_layers is not None:
            if isinstance(seg_guidance_layers, int):
                seg_guidance_layers = [seg_guidance_layers]
            if not isinstance(seg_guidance_layers, list):
                raise ValueError(
                    f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
                )
            seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]

        if isinstance(seg_guidance_config, dict):
            seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)

        if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
            seg_guidance_config = [seg_guidance_config]

        if not isinstance(seg_guidance_config, list):
            raise ValueError(
                f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
            )
        elif isinstance(next(iter(seg_guidance_config), None), dict):
            seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]

        self.seg_guidance_config = seg_guidance_config
        self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        self._count_prepared += 1  # mirrors SkipLayerGuidance; `is_conditional` relies on this counter
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
        elif self.num_conditions == 2:
            tuple_indices = [0, 1]
            input_predictions = (
                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
            )
        else:
            tuple_indices = [0, 1, 0]
            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor] = None,
        pred_cond_seg: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        pred = None

        if not self._is_cfg_enabled() and not self._is_seg_enabled():
            pred = pred_cond
        elif not self._is_cfg_enabled():
            shift = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_cond_seg
            pred = pred + self.seg_guidance_scale * shift
        elif not self._is_seg_enabled():
            shift = pred_cond - pred_uncond
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift
        else:
            shift = pred_cond - pred_uncond
            shift_seg = pred_cond - pred_cond_seg
            pred = pred_cond if self.use_original_formulation else pred_uncond
            pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        if self._is_seg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero
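
# --- Editor's sketch (not part of the commit): constructing a SEG guider. With the default
# `seg_blur_sigma` above the infinity threshold, the attention queries are blurred uniformly.
guider = SmoothedEnergyGuidance(
    guidance_scale=7.5,
    seg_guidance_scale=2.8,
    seg_guidance_layers=[7, 8, 9],  # docstring recommendation for Stable Diffusion 3.5 Medium
)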

src/diffusers/guiders/tangential_classifier_free_guidance.py (new file, 143 lines)
@@ -0,0 +1,143 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState


class TangentialClassifierFreeGuidance(BaseGuidance):
    """
    Tangential Classifier-Free Guidance (TCFG): https://huggingface.co/papers/2503.18137

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to
            saturation and deterioration of image quality.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By
            default, we use the diffusers-native implementation that has been in the codebase for a long time. See
            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
        start (`float`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts.
        stop (`float`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scale: float = 7.5,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
    ) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
            data_batches.append(data_batch)
        return data_batches

    def forward(
        self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        pred = None

        if not self._is_tcfg_enabled():
            pred = pred_cond
        else:
            pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1  # consistent with the other guiders, which track `_count_prepared`

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_tcfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_tcfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def normalized_guidance(
    pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
) -> torch.Tensor:
    cond_dtype = pred_cond.dtype
    preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
    preds = preds.flatten(2)
    U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
    Vh_modified = Vh.clone()
    Vh_modified[:, 1] = 0

    uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
    x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
    x_Vh_V = torch.matmul(x_Vh, Vh_modified)
    pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)

    pred = pred_cond if use_original_formulation else pred_uncond
    shift = pred_cond - pred_uncond
    pred = pred + guidance_scale * shift

    return pred
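
# --- Editor's sketch (not part of the commit): `normalized_guidance` zeroes out the second
# right-singular direction of the stacked predictions in the unconditional branch before the usual
# CFG combination. Shapes are arbitrary assumptions; the output matches the conditional's shape.
import torch

pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
pred = normalized_guidance(pred_cond, pred_uncond, guidance_scale=7.5)
assert pred.shape == pred_cond.shape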

src/diffusers/hooks/__init__.py
@@ -20,5 +20,7 @@ if is_torch_available():
    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
    from .hooks import HookRegistry, ModelHook
    from .layer_skip import LayerSkipConfig, apply_layer_skip
    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig

src/diffusers/hooks/_common.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention


_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
@@ -28,3 +34,10 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)


def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
    for submodule_name, submodule in module.named_modules():
        if submodule_name == fqn:
            return submodule
    return None
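
# --- Editor's sketch (not part of the commit): resolving a block stack by its fully qualified
# name. The module below is a stand-in assumption.
import torch

model = torch.nn.Module()
model.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(3)])
assert _get_submodule_from_fqn(model, "transformer_blocks") is model.transformer_blocks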

src/diffusers/hooks/layer_skip.py (new file, 254 lines)
@@ -0,0 +1,254 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import (
    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
    _ATTENTION_CLASSES,
    _FEEDFORWARD_CLASSES,
    _get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_LAYER_SKIP_HOOK = "layer_skip_hook"


# Aryan/YiYi TODO: we need to make the guider class a config mixin, so I think this is not needed;
# either remove it or make it serializable
@dataclass
class LayerSkipConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the layers to skip.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. `"auto"` only works on DiT models. For UNet models, you
            must provide the correct fqn.
        skip_attention (`bool`, defaults to `True`):
            Whether to skip attention blocks.
        skip_ff (`bool`, defaults to `True`):
            Whether to skip feed-forward blocks.
        skip_attention_scores (`bool`, defaults to `False`):
            Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
            projections as the output of scaled dot product attention.
        dropout (`float`, defaults to `1.0`):
            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to
            `1.0`, meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the
            outputs of the skipped layers are fully retained, which is equivalent to not skipping any layers.
    """

    indices: List[int]
    fqn: str = "auto"
    skip_attention: bool = True
    skip_attention_scores: bool = False
    skip_ff: bool = True
    dropout: float = 1.0

    def __post_init__(self):
        if not (0 <= self.dropout <= 1):
            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
            raise ValueError(
                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
            )

    def to_dict(self):
        return asdict(self)

    @staticmethod
    def from_dict(data: dict) -> "LayerSkipConfig":
        return LayerSkipConfig(**data)
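
# --- Editor's sketch (not part of the commit): `LayerSkipConfig` round-trips through a plain
# dict, which the guiders above also accept in place of the dataclass.
_config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks", skip_ff=False)
assert LayerSkipConfig.from_dict(_config.to_dict()) == _config  # dataclass equality compares fields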

class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            value = kwargs.get("value", None)
            if value is None:
                value = args[2]
            return value
        return func(*args, **kwargs)


class AttentionProcessorSkipHook(ModelHook):
    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
        self.skip_processor_output_fn = skip_processor_output_fn
        self.skip_attention_scores = skip_attention_scores
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.skip_attention_scores:
            if not math.isclose(self.dropout, 1.0):
                raise ValueError(
                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                )
            with AttentionScoreSkipFunctionMode():
                output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            if math.isclose(self.dropout, 1.0):
                output = self.skip_processor_output_fn(module, *args, **kwargs)
            else:
                output = self.fn_ref.original_forward(*args, **kwargs)
                output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


class FeedForwardSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            output = kwargs.get("hidden_states", None)
            if output is None:
                output = kwargs.get("x", None)
            if output is None and len(args) > 0:
                output = args[0]
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


class TransformerBlockSkipHook(ModelHook):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def initialize_hook(self, module):
        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if math.isclose(self.dropout, 1.0):
            original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
            if self._metadata.return_encoder_hidden_states_index is None:
                output = original_hidden_states
            else:
                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                    "encoder_hidden_states", args, kwargs
                )
                output = (original_hidden_states, original_encoder_hidden_states)
        else:
            output = self.fn_ref.original_forward(*args, **kwargs)
            output = torch.nn.functional.dropout(output, p=self.dropout)
        return output


def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
    r"""
    Apply layer skipping to internal layers of a transformer.

    Args:
        module (`torch.nn.Module`):
            The transformer model to which the layer skip hook should be applied.
        config (`LayerSkipConfig`):
            The configuration for the layer skip hook.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel, LayerSkipConfig, apply_layer_skip

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
    >>> apply_layer_skip(transformer, config)
    ```
    """
    _apply_layer_skip_hook(module, config)


def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
    name = name or _LAYER_SKIP_HOOK

    if config.skip_attention and config.skip_attention_scores:
        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
        raise ValueError(
            "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
        )

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
        raise ValueError(
            f"Could not find {config.fqn} in the provided module, or the configured `fqn` (fully qualified name) does "
            f"not identify a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of "
            f"transformer blocks."
        )
    if len(config.indices) == 0:
        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        if config.skip_ff:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):
                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = FeedForwardSkipHook(config.dropout)
                    registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )
|
||||
167 src/diffusers/hooks/smoothed_energy_guidance_utils.py Normal file
@@ -0,0 +1,167 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import asdict, dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F

from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"


@dataclass
class SmoothedEnergyGuidanceConfig:
    r"""
    Configuration for skipping internal transformer blocks when executing a transformer model.

    Args:
        indices (`List[int]`):
            The indices of the layers to skip. This is typically the first layer in the transformer block.
        fqn (`str`, defaults to `"auto"`):
            The fully qualified name identifying the stack of transformer blocks. Typically, this is
            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
            For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
            provide the correct fqn.
        _query_proj_identifiers (`List[str]`, defaults to `None`):
            The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
            `None`, `to_q` is used by default.
    """

    indices: List[int]
    fqn: str = "auto"
    _query_proj_identifiers: List[str] = None

    def to_dict(self):
        return asdict(self)

    @staticmethod
    def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
        return SmoothedEnergyGuidanceConfig(**data)


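# Example (illustrative): build a config targeting blocks 7 and 8 of a DiT-style
# model and round-trip it through a plain dict.
#
#     >>> config = SmoothedEnergyGuidanceConfig(indices=[7, 8], fqn="transformer_blocks")
#     >>> SmoothedEnergyGuidanceConfig.from_dict(config.to_dict()) == config
#     True

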
class SmoothedEnergyGuidanceHook(ModelHook):
    def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
        super().__init__()
        self.blur_sigma = blur_sigma
        self.blur_threshold_inf = blur_threshold_inf

    def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
        # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
        kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
        smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
        return smoothed_output


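# Note on the kernel size computed in `post_forward` above:
# `ceil(6 * sigma) + 1 - ceil(6 * sigma) % 2` is the smallest odd integer >= 6 * sigma,
# so the kernel spans roughly +/- 3 standard deviations. For example, sigma=1.0 gives
# ceil(6) + 1 - 0 = 7, i.e. a 7x7 blur kernel.

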
def _apply_smoothed_energy_guidance_hook(
    module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
) -> None:
    name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
                config.fqn = identifier
                break
        else:
            raise ValueError(
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    if config._query_proj_identifiers is None:
        config._query_proj_identifiers = ["to_q"]

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        for submodule_name, submodule in block.named_modules():
            if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
                continue
            for identifier in config._query_proj_identifiers:
                query_proj = getattr(submodule, identifier, None)
                if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
                    continue
                logger.debug(
                    f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
                )
                registry = HookRegistry.check_if_exists_or_initialize(query_proj)
                hook = SmoothedEnergyGuidanceHook(blur_sigma)
                registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
        )


# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
    """
    This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian
    blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally,
    this implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
    assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
    Smoothed Energy Guidance.

    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
    future without warning or guarantee of reproducibility.
    """
    assert query.ndim == 3

    is_inf = sigma > sigma_threshold_inf
    batch_size, seq_len, embed_dim = query.shape

    seq_len_sqrt = int(math.sqrt(seq_len))
    num_square_tokens = seq_len_sqrt * seq_len_sqrt
    query_slice = query[:, :num_square_tokens, :]
    query_slice = query_slice.permute(0, 2, 1)
    query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)

    if is_inf:
        # A sigma above the threshold is treated as an infinite blur: every visual
        # token collapses to the spatial mean.
        query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
    else:
        # Finite sigma: separable 2D Gaussian blur over the square token grid.
        kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
        kernel_size_half = (kernel_size - 1) / 2

        x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        kernel1d = pdf / pdf.sum()
        kernel1d = kernel1d.to(query)
        kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
        kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])

        padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
        query_slice = F.pad(query_slice, padding, mode="reflect")
        query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)

    query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
    query_slice = query_slice.permute(0, 2, 1)
    query[:, :num_square_tokens, :] = query_slice.clone()

    return query
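
# Example (illustrative): the blur above is separable, so the 2D kernel is the outer
# product of a normalized 1D Gaussian with itself. A minimal standalone sketch:
#
#     >>> import torch
#     >>> x = torch.linspace(-3, 3, steps=7)  # 7-tap kernel for sigma=1.0
#     >>> pdf = torch.exp(-0.5 * x.pow(2))
#     >>> kernel1d = pdf / pdf.sum()
#     >>> kernel2d = kernel1d[:, None] @ kernel1d[None, :]
#     >>> torch.allclose(kernel2d.sum(), torch.tensor(1.0))
#     True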
@@ -84,6 +84,7 @@ if is_torch_available():
        "IPAdapterMixin",
        "FluxIPAdapterMixin",
        "SD3IPAdapterMixin",
        "ModularIPAdapterMixin",
    ]

    _import_structure["peft"] = ["PeftAdapterMixin"]
@@ -101,6 +102,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .ip_adapter import (
            FluxIPAdapterMixin,
            IPAdapterMixin,
            ModularIPAdapterMixin,
            SD3IPAdapterMixin,
        )
        from .lora_pipeline import (

@@ -354,6 +354,256 @@ class IPAdapterMixin:
        self.unet.set_attn_processor(attn_procs)


class ModularIPAdapterMixin:
    """Mixin for handling IP Adapters."""

    @validate_hf_hub_args
    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
        subfolder: Union[str, List[str]],
        weight_name: Union[str, List[str]],
        **kwargs,
    ):
        """
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
            subfolder (`str` or `List[str]`):
                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
                list is passed, it should have the same length as `weight_name`.
            weight_name (`str` or `List[str]`):
                The name of the weight file to load. If a list is passed, it should have the same length as
                `subfolder`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """

        # handle the list inputs for multiple IP Adapters
        if not isinstance(weight_name, list):
            weight_name = [weight_name]

        if not isinstance(pretrained_model_name_or_path_or_dict, list):
            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
        if len(pretrained_model_name_or_path_or_dict) == 1:
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)

        if not isinstance(subfolder, list):
            subfolder = [subfolder]
        if len(subfolder) == 1:
            subfolder = subfolder * len(weight_name)

        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")

        if len(weight_name) != len(subfolder):
            raise ValueError("`weight_name` and `subfolder` must have the same length.")

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }
        state_dicts = []
        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
            pretrained_model_name_or_path_or_dict, weight_name, subfolder
        ):
            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                if weight_name.endswith(".safetensors"):
                    state_dict = {"image_proj": {}, "ip_adapter": {}}
                    with safe_open(model_file, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            if key.startswith("image_proj."):
                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                            elif key.startswith("ip_adapter."):
                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
                else:
                    state_dict = load_state_dict(model_file)
            else:
                state_dict = pretrained_model_name_or_path_or_dict

            keys = list(state_dict.keys())
            if "image_proj" not in keys and "ip_adapter" not in keys:
                raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

            state_dicts.append(state_dict)

        unet_name = getattr(self, "unet_name", "unet")
        unet = getattr(self, unet_name)
        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)

        extra_loras = unet._load_ip_adapter_loras(state_dicts)
        if extra_loras != {}:
            if not USE_PEFT_BACKEND:
                logger.warning("PEFT backend is required to load these weights.")
            else:
                # apply the IP Adapter Face ID LoRA weights
                peft_config = getattr(unet, "peft_config", {})
                for k, lora in extra_loras.items():
                    if f"faceid_{k}" not in peft_config:
                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])

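    # Example (hypothetical repo/weight names, for illustration only): the list handling
    # above broadcasts a single repo id and subfolder across multiple weight files.
    #
    #     >>> pipe.load_ip_adapter(
    #     ...     "h94/IP-Adapter",
    #     ...     subfolder="sdxl_models",
    #     ...     weight_name=["ip-adapter_sdxl.safetensors", "ip-adapter-plus_sdxl_vit-h.safetensors"],
    #     ... )
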
    def set_ip_adapter_scale(self, scale):
        """
        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

        Example:

        ```py
        # To use original IP-Adapter
        scale = 1.0
        pipeline.set_ip_adapter_scale(scale)

        # To use style block only
        scale = {
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style+layout blocks
        scale = {
            "down": {"block_2": [0.0, 1.0]},
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style and layout from 2 reference images
        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
        pipeline.set_ip_adapter_scale(scales)
        ```
        """
        unet_name = getattr(self, "unet_name", "unet")
        unet = getattr(self, unet_name)
        if not isinstance(scale, list):
            scale = [scale]
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

        for attn_name, attn_processor in unet.attn_processors.items():
            if isinstance(
                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
            ):
                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
                        f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
                    )
                elif len(scale_configs) == 1:
                    scale_configs = scale_configs * len(attn_processor.scale)
                for i, scale_config in enumerate(scale_configs):
                    if isinstance(scale_config, dict):
                        for k, s in scale_config.items():
                            if attn_name.startswith(k):
                                attn_processor.scale[i] = s
                    else:
                        attn_processor.scale[i] = scale_config

    def unload_ip_adapter(self):
        """
        Unloads the IP Adapter weights.

        Examples:

        ```python
        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
        >>> pipeline.unload_ip_adapter()
        >>> ...
        ```
        """

        # remove hidden encoder
        if self.unet is None:
            return

        self.unet.encoder_hid_proj = None
        self.unet.config.encoder_hid_dim_type = None

        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
            self.unet.text_encoder_hid_proj = None
            self.unet.config.encoder_hid_dim_type = "text_proj"

        # restore original Unet attention processors layers
        attn_procs = {}
        for name, value in self.unet.attn_processors.items():
            attn_processor_class = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
            )
            attn_procs[name] = (
                attn_processor_class
                if isinstance(
                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
                )
                else value.__class__()
            )
        self.unet.set_attn_processor(attn_procs)


class FluxIPAdapterMixin:
    """Mixin for handling Flux IP Adapters."""


@@ -25,7 +25,6 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE

from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
    USE_PEFT_BACKEND,
@@ -331,6 +330,8 @@ def _load_lora_into_text_encoder(
    hotswap: bool = False,
    metadata=None,
):
    from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

@@ -442,6 +443,8 @@ def _func_optionally_disable_offloading(_pipeline):
        tuple:
            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
    """
    from ..hooks.group_offloading import _is_group_offload_enabled

    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    is_group_offload = False

@@ -22,7 +22,6 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
@@ -164,6 +163,8 @@ class PeftAdapterMixin:
        from peft import inject_adapter_in_model, set_peft_model_state_dict
        from peft.tuners.tuners_utils import BaseTunerLayer

        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
@@ -695,6 +696,7 @@ class PeftAdapterMixin:
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unload_lora()`.")

        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
        from ..utils import recurse_remove_peft_layers

        recurse_remove_peft_layers(self)

@@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
    ImageProjection,
    IPAdapterFaceIDImageProjection,
@@ -132,6 +131,8 @@ class UNet2DConditionLoadersMixin:
        )
        ```
        """
        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)

84 src/diffusers/modular_pipelines/__init__.py Normal file
@@ -0,0 +1,84 @@
from typing import TYPE_CHECKING

from ..utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_pt_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["modular_pipeline"] = [
        "ModularPipelineBlocks",
        "ModularPipeline",
        "PipelineBlock",
        "AutoPipelineBlocks",
        "SequentialPipelineBlocks",
        "LoopSequentialPipelineBlocks",
        "PipelineState",
        "BlockState",
    ]
    _import_structure["modular_pipeline_utils"] = [
        "ComponentSpec",
        "ConfigSpec",
        "InputParam",
        "OutputParam",
        "InsertableDict",
    ]
    _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
    _import_structure["components_manager"] = ["ComponentsManager"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_pt_objects import *  # noqa F403
    else:
        from .components_manager import ComponentsManager
        from .modular_pipeline import (
            AutoPipelineBlocks,
            BlockState,
            LoopSequentialPipelineBlocks,
            ModularPipeline,
            ModularPipelineBlocks,
            PipelineBlock,
            PipelineState,
            SequentialPipelineBlocks,
        )
        from .modular_pipeline_utils import (
            ComponentSpec,
            ConfigSpec,
            InputParam,
            InsertableDict,
            OutputParam,
        )
        from .stable_diffusion_xl import (
            StableDiffusionXLAutoBlocks,
            StableDiffusionXLModularPipeline,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
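
# Example (illustrative): with this lazy module in place, the public entry points can
# be imported either from the subpackage or from the top-level `diffusers` namespace.
#
#     >>> from diffusers.modular_pipelines import ComponentsManager, ModularPipeline
#     >>> from diffusers import StableDiffusionXLAutoBlocks  # requires `transformers`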
1046 src/diffusers/modular_pipelines/components_manager.py Normal file
File diff suppressed because it is too large.
2827 src/diffusers/modular_pipelines/modular_pipeline.py Normal file
File diff suppressed because it is too large.
671 src/diffusers/modular_pipelines/modular_pipeline_utils.py Normal file
@@ -0,0 +1,671 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union

import torch

from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils import is_torch_available, logging


if is_torch_available():
    pass

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class InsertableDict(OrderedDict):
    def insert(self, key, value, index):
        items = list(self.items())

        # Remove key if it already exists to avoid duplicates
        items = [(k, v) for k, v in items if k != key]

        # Insert at the specified index
        items.insert(index, (key, value))

        # Clear and update self
        self.clear()
        self.update(items)

        # Return self for method chaining
        return self

    def __repr__(self):
        if not self:
            return "InsertableDict()"

        items = []
        for i, (key, value) in enumerate(self.items()):
            if isinstance(value, type):
                # For classes, show class name and <class ...>
                obj_repr = f"<class '{value.__module__}.{value.__name__}'>"
            else:
                # For objects (instances) and other types, show class name and module
                obj_repr = f"<obj '{value.__class__.__module__}.{value.__class__.__name__}'>"
            items.append(f"{i}: ({repr(key)}, {obj_repr})")

        return "InsertableDict([\n  " + ",\n  ".join(items) + "\n])"


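# Example (illustrative):
#
#     >>> d = InsertableDict(a=1, c=3)
#     >>> d.insert("b", 2, index=1)  # existing keys keep their order; "b" lands at position 1
#     >>> list(d)
#     ['a', 'b', 'c']

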
# YiYi TODO:
# 1. validate the dataclass fields
# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
    """Specification for a pipeline component.

    A component can be created in two ways:
    1. From scratch using __init__ with a config dict
    2. Using `from_pretrained`

    Attributes:
        name: Name of the component
        type_hint: Type of the component (e.g. UNet2DConditionModel)
        description: Optional description of the component
        config: Optional config dict for __init__ creation
        repo: Optional repo path for from_pretrained creation
        subfolder: Optional subfolder in repo
        variant: Optional variant in repo
        revision: Optional revision in repo
        default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
    """

    name: Optional[str] = None
    type_hint: Optional[Type] = None
    description: Optional[str] = None
    config: Optional[FrozenDict] = None
    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
    subfolder: Optional[str] = field(default=None, metadata={"loading": True})
    variant: Optional[str] = field(default=None, metadata={"loading": True})
    revision: Optional[str] = field(default=None, metadata={"loading": True})
    default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"

    def __hash__(self):
        """Make ComponentSpec hashable, using load_id as the hash value."""
        return hash((self.name, self.load_id, self.default_creation_method))

    def __eq__(self, other):
        """Compare ComponentSpec objects based on name and load_id."""
        if not isinstance(other, ComponentSpec):
            return False
        return (
            self.name == other.name
            and self.load_id == other.load_id
            and self.default_creation_method == other.default_creation_method
        )

    @classmethod
    def from_component(cls, name: str, component: Any) -> Any:
        """Create a ComponentSpec from a Component.

        Currently supports:
        - Components created with the `ComponentSpec.load()` method
        - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)

        Args:
            name: Name of the component
            component: Component object to create spec from

        Returns:
            ComponentSpec object

        Raises:
            ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
        """

        # Check if component was created with ComponentSpec.load()
        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
            # component has a usable load_id -> from_pretrained, no warning needed
            default_creation_method = "from_pretrained"
        else:
            # Component doesn't have a usable load_id, check if it's a nn.Module
            if isinstance(component, torch.nn.Module):
                raise ValueError(
                    "Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
                )
            # ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
            elif isinstance(component, ConfigMixin):
                # warn if component was not created with `ComponentSpec`
                if not hasattr(component, "_diffusers_load_id"):
                    logger.warning(
                        "Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
                    )
                default_creation_method = "from_config"
            else:
                # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
                raise ValueError(
                    f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
                    f" - components created with `ComponentSpec.load()` method"
                    f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
                )

        type_hint = component.__class__

        if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
            config = component.config
        else:
            config = None
        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
            load_spec = cls.decode_load_id(component._diffusers_load_id)
        else:
            load_spec = {}

        return cls(
            name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
        )

    @classmethod
    def loading_fields(cls) -> List[str]:
        """
        Return the names of all loading-related fields (i.e. those whose field.metadata["loading"] is True).
        """
        return [f.name for f in fields(cls) if f.metadata.get("loading", False)]

    @property
    def load_id(self) -> str:
        """
        Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
        segments).
        """
        parts = [getattr(self, k) for k in self.loading_fields()]
        parts = ["null" if p is None else p for p in parts]
        return "|".join(p for p in parts if p)

    @classmethod
    def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
        """
        Decode a load_id string back into a dictionary of loading fields and values.

        Args:
            load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
                where None values are represented as "null"

        Returns:
            Dict mapping loading field names to their values, e.g. {
                "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
            }. If a segment value is "null", it's replaced with None. If load_id is "null" (indicating a component
            not created with the `load` method), a dict with all values set to None is returned.
        """

        # Get all loading fields in order
        loading_fields = cls.loading_fields()
        result = {f: None for f in loading_fields}

        if load_id == "null":
            return result

        # Split the load_id
        parts = load_id.split("|")

        # Map parts to loading fields by position
        for i, part in enumerate(parts):
            if i < len(loading_fields):
                # Convert "null" string back to None
                result[loading_fields[i]] = None if part == "null" else part

        return result

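    # Example (hypothetical repo id, for illustration only): a spec loaded with
    # repo="stabilityai/stable-diffusion-xl-base-1.0" and subfolder="unet" (variant and
    # revision unset) has load_id "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null",
    # and `decode_load_id` maps it back to
    # {"repo": "stabilityai/stable-diffusion-xl-base-1.0", "subfolder": "unet", "variant": None, "revision": None}.
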
    # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
    # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
    # the config info is lost in the process
    # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method
    def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
        """Create component using from_config with config."""

        if self.type_hint is None or not isinstance(self.type_hint, type):
            raise ValueError("`type_hint` is required when using from_config creation method.")

        config = config or self.config or {}

        if issubclass(self.type_hint, ConfigMixin):
            component = self.type_hint.from_config(config, **kwargs)
        else:
            signature_params = inspect.signature(self.type_hint.__init__).parameters
            init_kwargs = {}
            for k, v in config.items():
                if k in signature_params:
                    init_kwargs[k] = v
            for k, v in kwargs.items():
                if k in signature_params:
                    init_kwargs[k] = v
            component = self.type_hint(**init_kwargs)

        component._diffusers_load_id = "null"
        if hasattr(component, "config"):
            self.config = component.config

        return component

    # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
    def load(self, **kwargs) -> Any:
        """Load component using from_pretrained."""

        # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision; note the list could change
        passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
        # merge loading field values in the spec with user passed values to create load_kwargs
        load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
        # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
        repo = load_kwargs.pop("repo", None)
        if repo is None:
            raise ValueError(
                "`repo` info is required when using `load` method (you can directly set it in the `repo` field of the ComponentSpec or pass it as an argument)"
            )

        if self.type_hint is None:
            try:
                from diffusers import AutoModel

                component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
            except Exception as e:
                raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
            # update type_hint if AutoModel loads successfully
            self.type_hint = component.__class__
        else:
            try:
                component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
            except Exception as e:
                raise ValueError(f"Unable to load {self.name} using load method: {e}")

        self.repo = repo
        for k, v in load_kwargs.items():
            setattr(self, k, v)
        component._diffusers_load_id = self.load_id

        return component


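# Example (hypothetical repo id, for illustration only):
#
#     >>> import torch
#     >>> from diffusers import UNet2DConditionModel
#     >>> spec = ComponentSpec(
#     ...     name="unet",
#     ...     type_hint=UNet2DConditionModel,
#     ...     repo="stabilityai/stable-diffusion-xl-base-1.0",
#     ...     subfolder="unet",
#     ... )
#     >>> unet = spec.load(torch_dtype=torch.float16)
#     >>> unet._diffusers_load_id
#     'stabilityai/stable-diffusion-xl-base-1.0|unet|null|null'

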
@dataclass
class ConfigSpec:
    """Specification for a pipeline configuration parameter."""

    name: str
    default: Any
    description: Optional[str] = None


# YiYi Notes: both inputs and intermediate_inputs are InputParam objects,
# however some fields are not relevant for intermediate_inputs:
# e.g. unlike inputs, `required` is only used in the docstring for intermediate_inputs; we do not check if a required intermediate input is passed
# `default` is not used for intermediate_inputs; we only use the default from inputs, so it is ignored if set for intermediate_inputs
# -> should we use different classes for inputs and intermediate_inputs?
@dataclass
class InputParam:
    """Specification for an input parameter."""

    name: str = None
    type_hint: Any = None
    default: Any = None
    required: bool = False
    description: str = ""
    kwargs_type: str = None  # YiYi Notes: remove this feature (maybe)

    def __repr__(self):
        return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"


@dataclass
class OutputParam:
    """Specification for an output parameter."""

    name: str
    type_hint: Any = None
    description: str = ""
    kwargs_type: str = None  # YiYi notes: remove this feature (maybe)

    def __repr__(self):
        return (
            f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
        )


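# Example (illustrative) of the reprs defined above:
#
#     >>> InputParam("guidance_scale", default=7.5)
#     <guidance_scale: optional, default=7.5>
#     >>> OutputParam("images", type_hint=list)
#     <images: list>

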
def format_inputs_short(inputs):
    """
    Format input parameters into a string representation, with required params first followed by optional ones.

    Args:
        inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params

    Returns:
        str: Formatted string of input parameters

    Example:
        >>> inputs = [
        ...     InputParam(name="prompt", required=True),
        ...     InputParam(name="image", required=True),
        ...     InputParam(name="guidance_scale", required=False, default=7.5),
        ...     InputParam(name="num_inference_steps", required=False, default=50),
        ... ]
        >>> format_inputs_short(inputs)
        'prompt, image, guidance_scale=7.5, num_inference_steps=50'
    """
    required_inputs = [param for param in inputs if param.required]
    optional_inputs = [param for param in inputs if not param.required]

    required_str = ", ".join(param.name for param in required_inputs)
    optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)

    inputs_str = required_str
    if optional_str:
        inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str

    return inputs_str


def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
    """
    Formats intermediate inputs and outputs of a block into a string representation.

    Args:
        intermediate_inputs: List of intermediate input parameters
        required_intermediate_inputs: List of required intermediate input names
        intermediate_outputs: List of intermediate output parameters

    Returns:
        str: Formatted string like:
            Intermediates:
                - inputs: Required(latents), dtype
                - modified: latents  # variables that appear in both inputs and outputs
                - outputs: images  # new outputs only
    """
    # Handle inputs
    input_parts = []
    for inp in intermediate_inputs:
        if inp.name in required_intermediate_inputs:
            input_parts.append(f"Required({inp.name})")
        else:
            if inp.name is None and inp.kwargs_type is not None:
                inp_name = "*_" + inp.kwargs_type
            else:
                inp_name = inp.name
            input_parts.append(inp_name)

    # Handle modified variables (appear in both inputs and outputs)
    inputs_set = {inp.name for inp in intermediate_inputs}
    modified_parts = []
    new_output_parts = []

    for out in intermediate_outputs:
        if out.name in inputs_set:
            modified_parts.append(out.name)
        else:
            new_output_parts.append(out.name)

    result = []
    if input_parts:
        result.append(f"    - inputs: {', '.join(input_parts)}")
    if modified_parts:
        result.append(f"    - modified: {', '.join(modified_parts)}")
    if new_output_parts:
        result.append(f"    - outputs: {', '.join(new_output_parts)}")

    return "\n".join(result) if result else "    (none)"


def format_params(params, header="Args", indent_level=4, max_line_length=115):
    """Format a list of InputParam or OutputParam objects into a readable string representation.

    Args:
        params: List of InputParam or OutputParam objects to format
        header: Header text to use (e.g. "Args" or "Returns")
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all parameters
    """
    if not params:
        return ""

    base_indent = " " * indent_level
    param_indent = " " * (indent_level + 4)
    desc_indent = " " * (indent_level + 8)
    formatted_params = []

    def get_type_str(type_hint):
        if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
            types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
            return f"Union[{', '.join(types)}]"
        return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)

    def wrap_text(text, indent, max_length):
        """Wrap text while preserving markdown links and maintaining indentation."""
        words = text.split()
        lines = []
        current_line = []
        current_length = 0

        for word in words:
            word_length = len(word) + (1 if current_line else 0)

            if current_line and current_length + word_length > max_length:
                lines.append(" ".join(current_line))
                current_line = [word]
                current_length = len(word)
            else:
                current_line.append(word)
                current_length += word_length

        if current_line:
            lines.append(" ".join(current_line))

        return f"\n{indent}".join(lines)

    # Add the header
    formatted_params.append(f"{base_indent}{header}:")

    for param in params:
        # Format parameter name and type
        type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
        # YiYi Notes: remove this line if we remove kwargs_type
        name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
        param_str = f"{param_indent}{name} (`{type_str}`"

        # Add optional tag and default value if parameter is an InputParam and optional
        if hasattr(param, "required"):
            if not param.required:
                param_str += ", *optional*"
                if param.default is not None:
                    param_str += f", defaults to {param.default}"
        param_str += "):"

        # Add description on a new line with additional indentation and wrapping
        if param.description:
            desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
            wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
            param_str += f"\n{desc_indent}{wrapped_desc}"

        formatted_params.append(param_str)

    return "\n\n".join(formatted_params)


def format_input_params(input_params, indent_level=4, max_line_length=115):
    """Format a list of InputParam objects into a readable string representation.

    Args:
        input_params: List of InputParam objects to format
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all input parameters
    """
    return format_params(input_params, "Inputs", indent_level, max_line_length)


def format_output_params(output_params, indent_level=4, max_line_length=115):
    """Format a list of OutputParam objects into a readable string representation.

    Args:
        output_params: List of OutputParam objects to format
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all output parameters
    """
    return format_params(output_params, "Outputs", indent_level, max_line_length)


def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
    """Format a list of ComponentSpec objects into a readable string representation.

    Args:
        components: List of ComponentSpec objects to format
        indent_level: Number of spaces to indent each component line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)
        add_empty_lines: Whether to add empty lines between components (default: True)

    Returns:
        A formatted string representing all components
    """
    if not components:
        return ""

    base_indent = " " * indent_level
    component_indent = " " * (indent_level + 4)
    formatted_components = []

    # Add the header
    formatted_components.append(f"{base_indent}Components:")
    if add_empty_lines:
        formatted_components.append("")

    # Add each component with optional empty lines between them
    for i, component in enumerate(components):
        # Get type name, handling special cases
        type_name = (
            component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
        )

        component_desc = f"{component_indent}{component.name} (`{type_name}`)"
        if component.description:
            component_desc += f": {component.description}"

        # Get the loading fields dynamically
        loading_field_values = []
        for field_name in component.loading_fields():
            field_value = getattr(component, field_name)
            if field_value is not None:
                loading_field_values.append(f"{field_name}={field_value}")

        # Add loading field information if available
        if loading_field_values:
            component_desc += f" [{', '.join(loading_field_values)}]"

        formatted_components.append(component_desc)

        # Add an empty line after each component except the last one
        if add_empty_lines and i < len(components) - 1:
            formatted_components.append("")

    return "\n".join(formatted_components)


def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
    """Format a list of ConfigSpec objects into a readable string representation.

    Args:
        configs: List of ConfigSpec objects to format
        indent_level: Number of spaces to indent each config line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)
        add_empty_lines: Whether to add empty lines between configs (default: True)

    Returns:
        A formatted string representing all configs
    """
    if not configs:
        return ""

    base_indent = " " * indent_level
    config_indent = " " * (indent_level + 4)
    formatted_configs = []

    # Add the header
    formatted_configs.append(f"{base_indent}Configs:")
    if add_empty_lines:
        formatted_configs.append("")

    # Add each config with optional empty lines between them
    for i, config in enumerate(configs):
        config_desc = f"{config_indent}{config.name} (default: {config.default})"
        if config.description:
            config_desc += f": {config.description}"
        formatted_configs.append(config_desc)

        # Add an empty line after each config except the last one
        if add_empty_lines and i < len(configs) - 1:
            formatted_configs.append("")

    return "\n".join(formatted_configs)


def make_doc_string(
    inputs,
    intermediate_inputs,
    outputs,
    description="",
    class_name=None,
    expected_components=None,
    expected_configs=None,
):
    """
    Generates a formatted documentation string describing the pipeline block's parameters and structure.

    Args:
        inputs: List of input parameters
        intermediate_inputs: List of intermediate input parameters
        outputs: List of output parameters
        description (str, *optional*): Description of the block
        class_name (str, *optional*): Name of the class to include in the documentation
        expected_components (List[ComponentSpec], *optional*): List of expected components
        expected_configs (List[ConfigSpec], *optional*): List of expected configurations

    Returns:
        str: A formatted string containing information about components, configs, call parameters,
            intermediate inputs/outputs, and final outputs.
    """
    output = ""

    # Add class name if provided
    if class_name:
        output += f"class {class_name}\n\n"

    # Add description
    if description:
        desc_lines = description.strip().split("\n")
        aligned_desc = "\n".join("  " + line for line in desc_lines)
        output += aligned_desc + "\n\n"

    # Add components section if provided
    if expected_components and len(expected_components) > 0:
        components_str = format_components(expected_components, indent_level=2)
        output += components_str + "\n\n"

    # Add configs section if provided
    if expected_configs and len(expected_configs) > 0:
        configs_str = format_configs(expected_configs, indent_level=2)
        output += configs_str + "\n\n"

    # Add inputs section
    output += format_input_params(inputs + intermediate_inputs, indent_level=2)

    # Add outputs section
    output += "\n\n"
    output += format_output_params(outputs, indent_level=2)

    return output
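
# Example (illustrative) of assembling a block docstring with the helpers above:
#
#     >>> doc = make_doc_string(
#     ...     inputs=[InputParam("prompt", type_hint=str, required=True, description="Text prompt")],
#     ...     intermediate_inputs=[],
#     ...     outputs=[OutputParam("images", description="The generated images")],
#     ...     description="Generate images from a text prompt.",
#     ...     class_name="ExampleBlock",
#     ... )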
665 src/diffusers/modular_pipelines/node_utils.py Normal file
@@ -0,0 +1,665 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam


logger = logging.getLogger(__name__)

# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
    "prompt": InputParam(
        "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
    ),
    "prompt_2": InputParam(
        "prompt_2",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
    ),
    "negative_prompt": InputParam(
        "negative_prompt",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts not to guide the image generation",
    ),
    "negative_prompt_2": InputParam(
        "negative_prompt_2",
        type_hint=Union[str, List[str]],
        description="The negative prompt or prompts for text_encoder_2",
    ),
    "cross_attention_kwargs": InputParam(
        "cross_attention_kwargs",
        type_hint=Optional[dict],
        description="Kwargs dictionary passed to the AttentionProcessor",
    ),
    "clip_skip": InputParam(
        "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
    ),
    "image": InputParam(
        "image",
        type_hint=PipelineImageInput,
        required=True,
        description="The image(s) to modify for img2img or inpainting",
    ),
    "mask_image": InputParam(
        "mask_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Mask image for inpainting, white pixels will be repainted",
    ),
    "generator": InputParam(
        "generator",
        type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
        description="Generator(s) for deterministic generation",
    ),
    "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
    "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
    "num_images_per_prompt": InputParam(
        "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
    ),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
    ),
    "timesteps": InputParam(
        "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
    ),
    "sigmas": InputParam(
        "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
    ),
    "denoising_end": InputParam(
        "denoising_end",
        type_hint=Optional[float],
        description="Fraction of denoising process to complete before termination",
    ),
    # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
    "strength": InputParam(
        "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
    ),
    "denoising_start": InputParam(
        "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
    ),
    "latents": InputParam(
        "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
    ),
    "padding_mask_crop": InputParam(
        "padding_mask_crop",
        type_hint=Optional[Tuple[int, int]],
        description="Size of margin in crop for image and mask",
    ),
    "original_size": InputParam(
        "original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Original size of the image for SDXL's micro-conditioning",
    ),
    "target_size": InputParam(
        "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
    ),
    "negative_original_size": InputParam(
        "negative_original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on image resolution",
    ),
    "negative_target_size": InputParam(
        "negative_target_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on target resolution",
    ),
    "crops_coords_top_left": InputParam(
        "crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Top-left coordinates for SDXL's micro-conditioning",
    ),
    "negative_crops_coords_top_left": InputParam(
        "negative_crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Negative conditioning crop coordinates",
    ),
    "aesthetic_score": InputParam(
        "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
    ),
    "negative_aesthetic_score": InputParam(
        "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
    ),
    "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
    "output_type": InputParam(
        "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
    ),
    "ip_adapter_image": InputParam(
        "ip_adapter_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Image(s) to be used as IP adapter",
    ),
    "control_image": InputParam(
        "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
    ),
    "control_guidance_start": InputParam(
        "control_guidance_start",
        type_hint=Union[float, List[float]],
        default=0.0,
        description="When ControlNet starts applying",
    ),
    "control_guidance_end": InputParam(
        "control_guidance_end",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="When ControlNet stops applying",
    ),
    "controlnet_conditioning_scale": InputParam(
        "controlnet_conditioning_scale",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="Scale factor for ControlNet outputs",
    ),
    "guess_mode": InputParam(
        "guess_mode",
        type_hint=bool,
        default=False,
        description="Enables ControlNet encoder to recognize input without prompts",
    ),
    "control_mode": InputParam(
        "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
    ),
}

SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
    "prompt_embeds": InputParam(
        "prompt_embeds",
        type_hint=torch.Tensor,
        required=True,
        description="Text embeddings used to guide image generation",
    ),
    "negative_prompt_embeds": InputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": InputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": InputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
    "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "preprocess_kwargs": InputParam(
        "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
    ),
    "latents": InputParam(
        "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
    ),
    "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
    ),
    "latent_timestep": InputParam(
        "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
    ),
    "image_latents": InputParam(
        "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
    ),
    "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
    "masked_image_latents": InputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "add_time_ids": InputParam(
        "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
    ),
    "negative_add_time_ids": InputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "ip_adapter_embeds": InputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": InputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": InputParam(
        "images",
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
        required=True,
        description="Generated images",
    ),
}

SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
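
# Note on merge order (illustrative): keys defined in both schemas -- e.g. "latents",
# "timesteps", "num_inference_steps" -- resolve to the intermediate version, since
# SDXL_INTERMEDIATE_INPUTS_SCHEMA is unpacked last:
#
#   SDXL_PARAM_SCHEMA["latents"].required   # True (intermediate spec wins)
#   SDXL_INPUTS_SCHEMA["latents"].required  # falsy (user-facing spec is optional)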


DEFAULT_PARAM_MAPS = {
    "prompt": {
        "label": "Prompt",
        "type": "string",
        "default": "a bear sitting in a chair drinking a milkshake",
        "display": "textarea",
    },
    "negative_prompt": {
        "label": "Negative Prompt",
        "type": "string",
        "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        "display": "textarea",
    },
    "num_inference_steps": {
        "label": "Steps",
        "type": "int",
        "default": 25,
        "min": 1,
        "max": 1000,
    },
    "seed": {
        "label": "Seed",
        "type": "int",
        "default": 0,
        "min": 0,
        "display": "random",
    },
    "width": {
        "label": "Width",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "height": {
        "label": "Height",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "images": {
        "label": "Images",
        "type": "image",
        "display": "output",
    },
    "image": {
        "label": "Image",
        "type": "image",
        "display": "input",
    },
}

DEFAULT_TYPE_MAPS = {
    "int": {
        "type": "int",
        "default": 0,
        "min": 0,
    },
    "float": {
        "type": "float",
        "default": 0.0,
        "min": 0.0,
    },
    "str": {
        "type": "string",
        "default": "",
    },
    "bool": {
        "type": "boolean",
        "default": False,
    },
    "image": {
        "type": "image",
    },
}
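
# How these maps are consumed (see ModularNode.__init__ below): the stringified,
# lower-cased type hint is matched by substring. For example,
# str(Optional[int]).lower() == "typing.optional[int]" contains "int", so an input
# annotated Optional[int] picks up the "int" spec above.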

DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
    "text_encoders": ["text_encoder", "tokenizer"],
    "ip_adapter_embeds": ["ip_adapter_embeds"],
    "prompt_embeddings": ["prompt_embeds"],
}


def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
    """
    Get the group name for a given parameter name; if the name is not part of a group, return None.
    e.g. "prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
    """
    if name is None:
        return None
    for group_name, group_keys in group_params_keys.items():
        for group_key in group_keys:
            if group_key in name:
                return group_name
    return None
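

# Illustrative, based on the default groups above (matching is by substring):
#   get_group_name("text_encoder_2")          -> "text_encoders"
#   get_group_name("negative_prompt_embeds")  -> "prompt_embeddings"
#   get_group_name("prompt")                  -> None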


class ModularNode(ConfigMixin):
    """
    A ModularNode is a base class to build UI nodes using diffusers. Currently it only supports Mellon. It is a
    wrapper around a ModularPipelineBlocks object.

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    config_name = "node_config.json"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        blocks = ModularPipelineBlocks.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
        self.blocks = blocks

        if label is None:
            label = self.blocks.__class__.__name__
        # blocks param name -> mellon param name
        self.name_mapping = {}

        input_params = {}
        # pass or create a default param dict for each input
        # e.g. for prompt,
        # prompt = {
        #     "name": "text_input",  # the name of the input in the node definition, can differ from the input name in diffusers
        #     "label": "Prompt",
        #     "type": "string",
        #     "default": "a bear sitting in a chair drinking a milkshake",
        #     "display": "textarea"}
        # if type is not specified, it'll be a "custom" param of its own type
        # e.g. you can pass ModularNode(scheduler={"name": "scheduler"});
        # it will get this spec in the node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
        # name can be a dict; in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder={"name": {"text_encoders": "text_encoder"}}
        inputs = self.blocks.inputs + self.blocks.intermediate_inputs
        for inp in inputs:
            param = kwargs.pop(inp.name, None)
            if param:
                # user can pass a param dict for any input, e.g. ModularNode(prompt={...})
                input_params[inp.name] = param
                mellon_name = param.pop("name", inp.name)
                if mellon_name != inp.name:
                    self.name_mapping[inp.name] = mellon_name
                continue

            if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
                continue

            if inp.name in DEFAULT_PARAM_MAPS:
                # first check if it's in the default param map; if so, directly use that
                param = DEFAULT_PARAM_MAPS[inp.name].copy()
            elif get_group_name(inp.name):
                param = get_group_name(inp.name)
                if inp.name not in self.name_mapping:
                    self.name_mapping[inp.name] = param
            else:
                # if not, check if it's in the SDXL input schema; if so,
                # 1. use the type hint to determine the type
                # 2. use the default param dict for the type, e.g. if "steps" is an "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
                if inp.type_hint is not None:
                    type_str = str(inp.type_hint).lower()
                else:
                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
                    if type_key in type_str:
                        param = type_param.copy()
                        param["label"] = inp.name
                        param["display"] = "input"
                        break
                else:
                    param = inp.name
            # add the param dict to the input_params dict
            input_params[inp.name] = param

        component_params = {}
        for comp in self.blocks.expected_components:
            param = kwargs.pop(comp.name, None)
            if param:
                component_params[comp.name] = param
                mellon_name = param.pop("name", comp.name)
                if mellon_name != comp.name:
                    self.name_mapping[comp.name] = mellon_name
                continue

            to_exclude = False
            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
                if exclude_key in comp.name:
                    to_exclude = True
                    break
            if to_exclude:
                continue

            if get_group_name(comp.name):
                param = get_group_name(comp.name)
                if comp.name not in self.name_mapping:
                    self.name_mapping[comp.name] = param
            elif comp.name in DEFAULT_MODEL_KEYS:
                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
            else:
                param = comp.name
            # add the param dict to the component_params dict
            component_params[comp.name] = param

        output_params = {}
        if isinstance(self.blocks, SequentialPipelineBlocks):
            last_block_name = list(self.blocks.sub_blocks.keys())[-1]
            outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
        else:
            outputs = self.blocks.intermediate_outputs

        for out in outputs:
            param = kwargs.pop(out.name, None)
            if param:
                output_params[out.name] = param
                mellon_name = param.pop("name", out.name)
                if mellon_name != out.name:
                    self.name_mapping[out.name] = mellon_name
                continue

            if out.name in DEFAULT_PARAM_MAPS:
                param = DEFAULT_PARAM_MAPS[out.name].copy()
                param["display"] = "output"
            else:
                group_name = get_group_name(out.name)
                if group_name:
                    param = group_name
                    if out.name not in self.name_mapping:
                        self.name_mapping[out.name] = param
                else:
                    param = out.name
            # add the param dict to the output_params dict
            output_params[out.name] = param

        if len(kwargs) > 0:
            logger.warning(f"Unused kwargs: {kwargs}")

        register_dict = {
            "category": category,
            "label": label,
            "input_params": input_params,
            "component_params": component_params,
            "output_params": output_params,
            "name_mapping": self.name_mapping,
        }
        self.register_to_config(**register_dict)
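
    # Illustrative construction (values are hypothetical): override how "prompt"
    # appears in the node definition while keeping defaults for everything else.
    #
    #   node = ModularNode(
    #       blocks,
    #       label="My SDXL Node",
    #       prompt={"name": "text_input", "label": "Prompt", "type": "string", "display": "textarea"},
    #   )
    #   node.name_mapping  # {"prompt": "text_input"}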

    def setup(self, components_manager, collection=None):
        self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
        self._components_manager = components_manager

    @property
    def mellon_config(self):
        return self._convert_to_mellon_config()

    def _convert_to_mellon_config(self):
        node = {}
        node["label"] = self.config.label
        node["category"] = self.config.category

        node_param = {}
        for inp_name, inp_param in self.config.input_params.items():
            if inp_name in self.name_mapping:
                mellon_name = self.name_mapping[inp_name]
            else:
                mellon_name = inp_name
            if isinstance(inp_param, str):
                param = {
                    "label": inp_param,
                    "type": inp_param,
                    "display": "input",
                }
            else:
                param = inp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")

        for comp_name, comp_param in self.config.component_params.items():
            if comp_name in self.name_mapping:
                mellon_name = self.name_mapping[comp_name]
            else:
                mellon_name = comp_name
            if isinstance(comp_param, str):
                param = {
                    "label": comp_param,
                    "type": comp_param,
                    "display": "input",
                }
            else:
                param = comp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Component param {mellon_name} already exists in node_param, skipping {comp_name}")

        for out_name, out_param in self.config.output_params.items():
            if out_name in self.name_mapping:
                mellon_name = self.name_mapping[out_name]
            else:
                mellon_name = out_name
            if isinstance(out_param, str):
                param = {
                    "label": out_param,
                    "type": out_param,
                    "display": "output",
                }
            else:
                param = out_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Output param {mellon_name} already exists in node_param, skipping {out_name}")
        node["params"] = node_param
        return node

    def save_mellon_config(self, file_path):
        """
        Save the Mellon configuration to a JSON file.

        Args:
            file_path (str or Path): Path where the JSON file will be saved

        Returns:
            Path: Path to the saved config file
        """
        file_path = Path(file_path)

        # Create directory if it doesn't exist
        os.makedirs(file_path.parent, exist_ok=True)

        # Create a combined dictionary with the module definition and name mapping
        config = {"module": self.mellon_config, "name_mapping": self.name_mapping}

        # Save the config to file
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info(f"Mellon config and name mapping saved to {file_path}")

        return file_path

    @classmethod
    def load_mellon_config(cls, file_path):
        """
        Load a Mellon configuration from a JSON file.

        Args:
            file_path (str or Path): Path to the JSON file containing the Mellon config

        Returns:
            dict: The loaded combined configuration containing 'module' and 'name_mapping'
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        logger.info(f"Mellon config loaded from {file_path}")

        return config
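
    # Round-trip sketch (paths are hypothetical):
    #
    #   path = node.save_mellon_config("configs/my_node.json")
    #   config = ModularNode.load_mellon_config(path)
    #   config["module"]        # the node definition dict (label/category/params)
    #   config["name_mapping"]  # blocks param name -> mellon param name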

    def process_inputs(self, **kwargs):
        params_components = {}
        for comp_name, comp_param in self.config.component_params.items():
            logger.debug(f"component: {comp_name}")
            mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
            if mellon_comp_name in kwargs:
                if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
                    comp = kwargs[mellon_comp_name].pop(comp_name)
                else:
                    comp = kwargs.pop(mellon_comp_name)
                if comp:
                    params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

        params_run = {}
        for inp_name, inp_param in self.config.input_params.items():
            logger.debug(f"input: {inp_name}")
            mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
            if mellon_inp_name in kwargs:
                if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
                    inp = kwargs[mellon_inp_name].pop(inp_name)
                else:
                    inp = kwargs.pop(mellon_inp_name)
                if inp is not None:
                    params_run[inp_name] = inp

        return_output_names = list(self.config.output_params.keys())

        return params_components, params_run, return_output_names

    def execute(self, **kwargs):
        params_components, params_run, return_output_names = self.process_inputs(**kwargs)

        self.pipeline.update_components(**params_components)
        output = self.pipeline(**params_run, output=return_output_names)
        return output
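
    # Typical lifecycle sketch (assuming a ComponentsManager instance `manager`;
    # repo id and kwargs are illustrative, not real values):
    #
    #   node = ModularNode.from_pretrained("some/repo")
    #   node.setup(manager)
    #   output = node.execute(prompt="a cat", num_inference_steps=25)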
@@ -0,0 +1,77 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "CONTROLNET_BLOCKS",
        "IMAGE2IMAGE_BLOCKS",
        "INPAINT_BLOCKS",
        "IP_ADAPTER_BLOCKS",
        "TEXT2IMAGE_BLOCKS",
        "StableDiffusionXLAutoBlocks",
        "StableDiffusionXLAutoControlnetStep",
        "StableDiffusionXLAutoDecodeStep",
        "StableDiffusionXLAutoIPAdapterStep",
        "StableDiffusionXLAutoVaeEncoderStep",
    ]
    _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .encoders import (
            StableDiffusionXLTextEncoderStep,
        )
        from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
            CONTROLNET_BLOCKS,
            IMAGE2IMAGE_BLOCKS,
            INPAINT_BLOCKS,
            IP_ADAPTER_BLOCKS,
            TEXT2IMAGE_BLOCKS,
            StableDiffusionXLAutoBlocks,
            StableDiffusionXLAutoControlnetStep,
            StableDiffusionXLAutoDecodeStep,
            StableDiffusionXLAutoIPAdapterStep,
            StableDiffusionXLAutoVaeEncoderStep,
        )
        from .modular_pipeline import StableDiffusionXLModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py (Normal file, 218 lines)
@@ -0,0 +1,218 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import (
    PipelineBlock,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionXLDecodeStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            )
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
                description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @staticmethod
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
    def upcast_vae(components):
        dtype = components.vae.dtype
        components.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            components.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
            ),
        )
        # if xformers or torch_2_0 is used, the attention block does not need
        # to be in float32, which can save lots of memory
        if use_torch_2_0_or_xformers:
            components.vae.post_quant_conv.to(dtype)
            components.vae.decoder.conv_in.to(dtype)
            components.vae.decoder.mid_block.to(dtype)

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        if not block_state.output_type == "latent":
            latents = block_state.latents
            # make sure the VAE is in float32 mode, as it overflows in float16
            block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast

            if block_state.needs_upcasting:
                self.upcast_vae(components)
                latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != components.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    components.vae = components.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            block_state.has_latents_mean = (
                hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
            )
            block_state.has_latents_std = (
                hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
            )
            if block_state.has_latents_mean and block_state.has_latents_std:
                block_state.latents_mean = (
                    torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                block_state.latents_std = (
                    torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = (
                    latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
                )
            else:
                latents = latents / components.vae.config.scaling_factor

            block_state.images = components.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if block_state.needs_upcasting:
                components.vae.to(dtype=torch.float16)
        else:
            block_state.images = block_state.latents

        # apply watermark if available
        if hasattr(components, "watermark") and components.watermark is not None:
            block_state.images = components.watermark.apply_watermark(block_state.images)

        block_state.images = components.image_processor.postprocess(
            block_state.images, output_type=block_state.output_type
        )

        self.set_block_state(state, block_state)

        return components, state
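
    # For reference: when latents_mean/latents_std are present in the VAE config,
    # the denormalization above computes
    #   latents = latents * latents_std / scaling_factor + latents_mean
    # otherwise it falls back to the usual  latents / scaling_factor.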


class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "A post-processing step that overlays the mask on the image (inpainting task only).\n"
            + "Only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("image"),
            InputParam("mask_image"),
            InputParam("padding_mask_crop"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
                description="The generated images from the decode step",
            ),
            InputParam(
                "crops_coords",
                type_hint=Tuple[int, int],
                description="The crop coordinates to use for preprocess/postprocess of the image and mask, for inpainting task only. Can be generated in vae_encode step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
            block_state.images = [
                components.image_processor.apply_overlay(
                    block_state.mask_image, block_state.image, i, block_state.crops_coords
                )
                for i in block_state.images
            ]

        self.set_block_state(state, block_state)

        return components, state
src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py (Normal file, 791 lines)
@@ -0,0 +1,791 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, List, Optional, Tuple

import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ControlNetModel, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    PipelineBlock,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi experimenting: composable denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent input for the denoiser. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
        )

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)

        return components, block_state


# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent input for the denoiser (for the inpainting workflow only). "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object"
        )

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "mask",
                type_hint=Optional[torch.Tensor],
                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
            ),
            InputParam(
                "masked_image_latents",
                type_hint=Optional[torch.Tensor],
                description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
            ),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        num_channels_unet = components.num_channels_unet
        if num_channels_unet == 9:
            # default case for runwayml/stable-diffusion-inpainting
            if block_state.mask is None or block_state.masked_image_latents is None:
                raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
            num_channels_latents = block_state.latents.shape[1]
            num_channels_mask = block_state.mask.shape[1]
            num_channels_masked_image = block_state.masked_image_latents.shape[1]
            if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
                raise ValueError(
                    f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
                    f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                    " `components.unet` or your `mask_image` or `image` input."
                )
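
    # For reference (standard inpainting checkpoints such as runwayml/stable-diffusion-inpainting):
    # in_channels == 9 comes from 4 (noisy latents) + 1 (mask) + 4 (masked image latents).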

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        self.check_inputs(components, block_state)

        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
        if components.num_channels_unet == 9:
            block_state.scaled_latents = torch.cat(
                [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
            )

        return components, block_state


# loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("cross_attention_kwargs"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "timestep_cond",
                type_hint=Optional[torch.Tensor],
                description="The guidance scale embedding to use for Latent Consistency Models (LCMs). Can be generated in prepare_additional_conditioning step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds, "
                    "add_time_ids/negative_add_time_ids, "
                    "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
                    "and ip_adapter_embeds/negative_ip_adapter_embeds (optional). "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int
    ) -> PipelineState:
        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
        # to the corresponding (cond, uncond) fields on block_state (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds).
        guider_input_fields = {
            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
            "time_ids": ("add_time_ids", "negative_add_time_ids"),
            "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
            "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
        }

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Prepare mini-batches according to guidance method and `guider_input_fields`.
        # Each guider_state_batch will have .prompt_embeds, .time_ids, .text_embeds, .image_embeds.
        # e.g. for CFG, we prepare two batches: one for cond, one for uncond.
        # For the first batch, guider_state_batch.prompt_embeds corresponds to block_state.prompt_embeds;
        # for the second batch, guider_state_batch.prompt_embeds corresponds to block_state.negative_prompt_embeds.
        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)

        # run the denoiser for each guidance batch
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.unet)
            cond_kwargs = guider_state_batch.as_dict()
            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
            prompt_embeds = cond_kwargs.pop("prompt_embeds")

            # Predict the noise residual.
            # Store the noise_pred in guider_state_batch so that we can apply guidance across all batches.
            guider_state_batch.noise_pred = components.unet(
                block_state.scaled_latents,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=block_state.timestep_cond,
                cross_attention_kwargs=block_state.cross_attention_kwargs,
                added_cond_kwargs=cond_kwargs,
                return_dict=False,
            )[0]
            components.guider.cleanup_models(components.unet)

        # Perform guidance
        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

        return components, block_state
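
# For intuition: with the default ClassifierFreeGuidance guider above, the final
# `components.guider(guider_state)` call reduces to the standard CFG combination
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# other guiders combine the per-batch predictions differently.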


# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec("controlnet", ControlNetModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance (with controlnet). "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("cross_attention_kwargs"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "controlnet_cond",
                required=True,
                type_hint=torch.Tensor,
                description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
            ),
            InputParam(
                "conditioning_scale",
                type_hint=float,
                description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
            ),
            InputParam(
                "guess_mode",
                required=True,
                type_hint=bool,
                description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
            ),
            InputParam(
                "controlnet_keep",
                required=True,
                type_hint=List[float],
                description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
            ),
            InputParam(
                "timestep_cond",
                type_hint=Optional[torch.Tensor],
                description="The guidance scale embedding to use for Latent Consistency Models (LCMs). Can be generated by prepare_additional_conditioning step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds, "
                    "add_time_ids/negative_add_time_ids, "
                    "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
                    "and ip_adapter_embeds/negative_ip_adapter_embeds (optional). "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
                ),
            ),
            InputParam(
                kwargs_type="controlnet_kwargs",
                description=(
                    "Additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step). "
                    "Please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
                ),
            ),
        ]

    @staticmethod
    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
        accepted_kwargs = set(inspect.signature(func).parameters.keys())
        extra_kwargs = {}
        for key, value in kwargs.items():
            if key in accepted_kwargs and key not in exclude_kwargs:
                extra_kwargs[key] = value

        return extra_kwargs
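
    # Illustrative behaviour of the signature filter above (hypothetical callable):
    #
    #   def step(model_output, timestep, sample, eta=0.0, generator=None): ...
    #
    #   prepare_extra_kwargs(step, eta=0.5, generator=g, foo=1)
    #   # -> {"eta": 0.5, "generator": g}; "foo" is dropped since `step` doesn't accept it.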
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
|
||||
extra_controlnet_kwargs = self.prepare_extra_kwargs(
|
||||
components.controlnet.forward, **block_state.controlnet_kwargs
|
||||
)
|
||||
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_input_fields = {
|
||||
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"time_ids": ("add_time_ids", "negative_add_time_ids"),
|
||||
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
|
||||
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
|
||||
}
|
||||
|
||||
# cond_scale for the timestep (controlnet input)
|
||||
if isinstance(block_state.controlnet_keep[i], list):
|
||||
block_state.cond_scale = [
|
||||
c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])
|
||||
]
|
||||
else:
|
||||
controlnet_cond_scale = block_state.conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
|
||||
|
||||
# default controlnet output/unet input for guess mode + conditional path
|
||||
block_state.down_block_res_samples_zeros = None
|
||||
block_state.mid_block_res_sample_zeros = None
|
||||
|
||||
# guided denoiser step
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.unet)
|
||||
|
||||
# Prepare additional conditionings
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": guider_state_batch.text_embeds,
|
||||
"time_ids": guider_state_batch.time_ids,
|
||||
}
|
||||
if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
|
||||
added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
|
||||
|
||||
# Prepare controlnet additional conditionings
|
||||
controlnet_added_cond_kwargs = {
|
||||
"text_embeds": guider_state_batch.text_embeds,
|
||||
"time_ids": guider_state_batch.time_ids,
|
||||
}
|
||||
# run controlnet for the guidance batch
|
||||
if block_state.guess_mode and not components.guider.is_conditional:
|
||||
# guider always run uncond batch first, so these tensors should be set already
|
||||
down_block_res_samples = block_state.down_block_res_samples_zeros
|
||||
mid_block_res_sample = block_state.mid_block_res_sample_zeros
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = components.controlnet(
|
||||
block_state.scaled_latents,
|
||||
t,
|
||||
encoder_hidden_states=guider_state_batch.prompt_embeds,
|
||||
controlnet_cond=block_state.controlnet_cond,
|
||||
conditioning_scale=block_state.cond_scale,
|
||||
guess_mode=block_state.guess_mode,
|
||||
added_cond_kwargs=controlnet_added_cond_kwargs,
|
||||
return_dict=False,
|
||||
**extra_controlnet_kwargs,
|
||||
)
|
||||
|
||||
# assign it to block_state so it will be available for the uncond guidance batch
|
||||
if block_state.down_block_res_samples_zeros is None:
|
||||
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
|
||||
if block_state.mid_block_res_sample_zeros is None:
|
||||
block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
|
||||
|
||||
# Predict the noise
|
||||
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
|
||||
guider_state_batch.noise_pred = components.unet(
|
||||
block_state.scaled_latents,
|
||||
t,
|
||||
encoder_hidden_states=guider_state_batch.prompt_embeds,
|
||||
timestep_cond=block_state.timestep_cond,
|
||||
cross_attention_kwargs=block_state.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.unet)
|
||||
|
||||
# Perform guidance
|
||||
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("eta", default=0.0),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    # YiYi TODO: move this out of here
    @staticmethod
    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
        accepted_kwargs = set(inspect.signature(func).parameters.keys())
        extra_kwargs = {}
        for key, value in kwargs.items():
            if key in accepted_kwargs and key not in exclude_kwargs:
                extra_kwargs[key] = value

        return extra_kwargs

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        block_state.extra_step_kwargs = self.prepare_extra_kwargs(
            components.scheduler.step, generator=block_state.generator, eta=block_state.eta
        )

        # Perform scheduler step using the predicted output
        block_state.latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            **block_state.extra_step_kwargs,
            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != block_state.latents_dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                block_state.latents = block_state.latents.to(block_state.latents_dtype)

        return components, block_state

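
# `prepare_extra_kwargs` above forwards only the kwargs the target callable accepts, so
# schedulers whose `step` lacks `eta` or `generator` simply don't receive them. A
# standalone sketch of the same filtering against a toy scheduler signature:
import inspect as _inspect


def _prepare_extra_kwargs(func, exclude_kwargs=(), **kwargs):
    accepted = set(_inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted and k not in exclude_kwargs}


def _toy_step(model_output, timestep, sample, generator=None):  # no `eta` parameter
    return sample


_extra = _prepare_extra_kwargs(_toy_step, generator=None, eta=0.0)
assert "generator" in _extra and "eta" not in _extra
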
# loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that updates the latents (for inpainting workflow only). "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("eta", default=0.0),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "mask",
                type_hint=Optional[torch.Tensor],
                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
            ),
            InputParam(
                "noise",
                type_hint=Optional[torch.Tensor],
                description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_latents",
                type_hint=Optional[torch.Tensor],
                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @staticmethod
    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
        accepted_kwargs = set(inspect.signature(func).parameters.keys())
        extra_kwargs = {}
        for key, value in kwargs.items():
            if key in accepted_kwargs and key not in exclude_kwargs:
                extra_kwargs[key] = value

        return extra_kwargs

    def check_inputs(self, components, block_state):
        if components.num_channels_unet == 4:
            if block_state.image_latents is None:
                raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
            if block_state.mask is None:
                raise ValueError(f"mask is required for this step {self.__class__.__name__}")
            if block_state.noise is None:
                raise ValueError(f"noise is required for this step {self.__class__.__name__}")

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        self.check_inputs(components, block_state)

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        block_state.extra_step_kwargs = self.prepare_extra_kwargs(
            components.scheduler.step, generator=block_state.generator, eta=block_state.eta
        )

        # Perform scheduler step using the predicted output
        block_state.latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            **block_state.extra_step_kwargs,
            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != block_state.latents_dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                block_state.latents = block_state.latents.to(block_state.latents_dtype)

        # adjust latent for inpainting
        if components.num_channels_unet == 4:
            block_state.init_latents_proper = block_state.image_latents
            if i < len(block_state.timesteps) - 1:
                block_state.noise_timestep = block_state.timesteps[i + 1]
                block_state.init_latents_proper = components.scheduler.add_noise(
                    block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
                )

            block_state.latents = (
                1 - block_state.mask
            ) * block_state.init_latents_proper + block_state.mask * block_state.latents

        return components, block_state

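
# A toy check of the inpainting blend used above: unmasked positions (mask == 0) are
# pinned to the re-noised original latents, while masked positions (mask == 1) keep the
# denoised values:
import torch as _torch

_mask = _torch.tensor([[0.0, 1.0]])
_init_latents_proper = _torch.tensor([[10.0, 10.0]])
_denoised = _torch.tensor([[3.0, 3.0]])
_blended = (1 - _mask) * _init_latents_proper + _mask * _denoised
assert _blended.tolist() == [[10.0, 3.0]]
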
# the loop wrapper that iterates over the timesteps
class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with the `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

    @property
    def loop_intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
        if block_state.disable_guidance:
            components.guider.disable()
        else:
            components.guider.enable()

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        self.set_block_state(state, block_state)

        return components, state

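
# The progress bar above advances once per full scheduler-order group of timesteps. For
# example, with a hypothetical second-order scheduler that yields 21 timesteps for 10
# requested inference steps, the first timestep counts as warmup:
_num_inference_steps, _scheduler_order, _num_timesteps = 10, 2, 21
_num_warmup_steps = max(_num_timesteps - _num_inference_steps * _scheduler_order, 0)
assert _num_warmup_steps == 1
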
# composing the denoising loops
class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [
        StableDiffusionXLLoopBeforeDenoiser,
        StableDiffusionXLLoopDenoiser,
        StableDiffusionXLLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `StableDiffusionXLDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `StableDiffusionXLLoopBeforeDenoiser`\n"
            " - `StableDiffusionXLLoopDenoiser`\n"
            " - `StableDiffusionXLLoopAfterDenoiser`\n"
            "This block supports both text2img and img2img tasks."
        )


# control_cond
class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [
        StableDiffusionXLLoopBeforeDenoiser,
        StableDiffusionXLControlNetLoopDenoiser,
        StableDiffusionXLLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents with controlnet. \n"
            "Its loop logic is defined in the `StableDiffusionXLDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `StableDiffusionXLLoopBeforeDenoiser`\n"
            " - `StableDiffusionXLControlNetLoopDenoiser`\n"
            " - `StableDiffusionXLLoopAfterDenoiser`\n"
            "This block supports using controlnet for both text2img and img2img tasks."
        )


# mask
class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [
        StableDiffusionXLInpaintLoopBeforeDenoiser,
        StableDiffusionXLLoopDenoiser,
        StableDiffusionXLInpaintLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents (for inpainting task only). \n"
            "Its loop logic is defined in the `StableDiffusionXLDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
            " - `StableDiffusionXLLoopDenoiser`\n"
            " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
            "This block only supports inpainting tasks."
        )


# control_cond + mask
class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [
        StableDiffusionXLInpaintLoopBeforeDenoiser,
        StableDiffusionXLControlNetLoopDenoiser,
        StableDiffusionXLInpaintLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents (for inpainting task only) with controlnet. \n"
            "Its loop logic is defined in the `StableDiffusionXLDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
            " - `StableDiffusionXLControlNetLoopDenoiser`\n"
            " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
            "This block only supports using controlnet for inpainting tasks."
        )

src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py (new file)
@@ -0,0 +1,902 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

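
# `retrieve_latents` accepts either a VAE encoder output carrying a `latent_dist` or an
# object that already exposes `latents`. A toy illustration with stand-in classes
# (`_Dist` and `_EncoderOutput` are hypothetical, for illustration only):
from dataclasses import dataclass as _dataclass


@_dataclass
class _Dist:
    mean: torch.Tensor

    def sample(self, generator=None):
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        return self.mean


@_dataclass
class _EncoderOutput:
    latent_dist: _Dist


_out = _EncoderOutput(latent_dist=_Dist(mean=torch.zeros(1, 4, 8, 8)))
_stochastic = retrieve_latents(_out, sample_mode="sample")  # draws from the posterior
_deterministic = retrieve_latents(_out, sample_mode="argmax")  # takes the distribution mode
assert torch.equal(_deterministic, _out.latent_dist.mean)
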
class StableDiffusionXLIPAdapterStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "IP Adapter step that prepares ip adapter image embeddings.\n"
            "Note that this step only prepares the embeddings - in order for it to work correctly, "
            "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
            "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
            " for more details"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
            ComponentSpec(
                "feature_extractor",
                CLIPImageProcessor,
                config=FrozenDict({"size": 224, "crop_size": 224}),
                default_creation_method="from_config",
            ),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "ip_adapter_image",
                PipelineImageInput,
                required=True,
                description="The image(s) to be used as ip adapter",
            )
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
            OutputParam(
                "negative_ip_adapter_embeds",
                type_hint=torch.Tensor,
                description="Negative IP adapter image embeddings",
            ),
        ]

    @staticmethod
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
    def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(components.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = components.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = components.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = components.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self,
        components,
        ip_adapter_image,
        ip_adapter_image_embeds,
        device,
        num_images_per_prompt,
        prepare_unconditional_embeds,
    ):
        image_embeds = []
        if prepare_unconditional_embeds:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    components, single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if prepare_unconditional_embeds:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if prepare_unconditional_embeds:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if prepare_unconditional_embeds:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
            components,
            ip_adapter_image=block_state.ip_adapter_image,
            ip_adapter_image_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,
            prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
        )
        if block_state.prepare_unconditional_embeds:
            block_state.negative_ip_adapter_embeds = []
            for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
                negative_image_embeds, image_embeds = image_embeds.chunk(2)
                block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
                block_state.ip_adapter_embeds[i] = image_embeds

        self.set_block_state(state, block_state)
        return components, state

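
# Precomputed IP adapter embeddings are packed with the negative half first, which is why
# `chunk(2)` above recovers (negative, positive) in that order. A toy check of the layout:
_negative = torch.zeros(1, 4)
_positive = torch.ones(1, 4)
_packed = torch.cat([_negative, _positive], dim=0)
_neg_out, _pos_out = _packed.chunk(2)
assert torch.equal(_neg_out, _negative) and torch.equal(_pos_out, _positive)
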
class StableDiffusionXLTextEncoderStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text_embeddings to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("tokenizer_2", CLIPTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("force_zeros_for_empty_prompt", True)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("negative_prompt"),
            InputParam("negative_prompt_2"),
            InputParam("cross_attention_kwargs"),
            InputParam("clip_skip"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="negative text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_pooled_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="negative pooled text embeddings used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and (
            not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
        ):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
        elif block_state.prompt_2 is not None and (
            not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
        ):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")

    @staticmethod
    def encode_prompt(
        components,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional embeddings or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
                else:
                    scale_lora_layers(components.text_encoder, lora_scale)

            if components.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = (
            [components.tokenizer, components.tokenizer_2]
            if components.tokenizer is not None
            else [components.tokenizer_2]
        )
        text_encoders = (
            [components.text_encoder, components.text_encoder_2]
            if components.text_encoder is not None
            else [components.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(components, TextualInversionLoaderMixin):
                    prompt = components.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
        if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(components, TextualInversionLoaderMixin):
                    negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if components.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if components.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=components.text_encoder_2.dtype, device=device
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if prepare_unconditional_embeds:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if components.text_encoder is not None:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

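
# A quick check of the mps-friendly duplication used in `encode_prompt`: `repeat` along
# the sequence dimension followed by `view` yields num_images_per_prompt adjacent copies
# per prompt:
_bs, _seq, _dim, _n = 2, 3, 4, 2
_embeds = torch.arange(_bs * _seq * _dim, dtype=torch.float32).reshape(_bs, _seq, _dim)
_dup = _embeds.repeat(1, _n, 1).view(_bs * _n, _seq, _dim)
assert _dup.shape == (4, 3, 4)
assert torch.equal(_dup[0], _dup[1])  # copies of the same prompt sit next to each other
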
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.cross_attention_kwargs.get("scale", None)
            if block_state.cross_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.prompt_2,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            block_state.negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=block_state.text_encoder_lora_scale,
            clip_skip=block_state.clip_skip,
        )
        # Add outputs
        self.set_block_state(state, block_state)
        return components, state

class StableDiffusionXLVaeEncoderStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Vae Encoder step that encodes the input image into a latent representation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", required=True),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam(
                "preprocess_kwargs",
                type_hint=Optional[dict],
                description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation",
            )
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.image = components.image_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
        )
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)

        block_state.batch_size = block_state.image.shape[0]

        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
            )

        block_state.image_latents = self._encode_vae_image(
            components, image=block_state.image, generator=block_state.generator
        )

        self.set_block_state(state, block_state)

        return components, state

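
# The two scaling paths in `_encode_vae_image` above, as a standalone sketch (0.13025 is
# used as an assumed example value for the VAE `scaling_factor`):
def _scale_latents(image_latents, scaling_factor=0.13025, latents_mean=None, latents_std=None):
    if latents_mean is not None and latents_std is not None:
        return (image_latents - latents_mean) * scaling_factor / latents_std
    return scaling_factor * image_latents


_latents = torch.randn(1, 4, 8, 8)
_plain = _scale_latents(_latents)
_standardized = _scale_latents(_latents, latents_mean=torch.zeros(1, 4, 1, 1), latents_std=torch.ones(1, 4, 1, 1))
assert torch.allclose(_plain, _standardized)  # identical when mean=0 and std=1
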
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "mask_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Vae encoder step that prepares the image and mask for the inpainting process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height"),
            InputParam("width"),
            InputParam("image", required=True),
            InputParam("mask_image", required=True),
            InputParam("padding_mask_crop"),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
            ),
            OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
            OutputParam(
                "masked_image_latents",
                type_hint=torch.Tensor,
                description="The masked image latents to use for the inpainting process (only for an inpainting-specific unet)",
            ),
            OutputParam(
                "crops_coords",
                type_hint=Optional[Tuple[int, int]],
                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
            ),
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
    # do not accept do_classifier_free_guidance
    def prepare_mask_latents(
        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

        if masked_image is not None and masked_image.shape[1] == 4:
            masked_image_latents = masked_image
        else:
            masked_image_latents = None

        if masked_image is not None:
            if masked_image_latents is None:
                masked_image = masked_image.to(device=device, dtype=dtype)
                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

            if masked_image_latents.shape[0] < batch_size:
                if not batch_size % masked_image_latents.shape[0] == 0:
                    raise ValueError(
                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                        " Make sure the number of images that you pass is divisible by the total requested batch size."
                    )
                masked_image_latents = masked_image_latents.repeat(
                    batch_size // masked_image_latents.shape[0], 1, 1, 1
                )

            # aligning device to prevent device errors when concating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        return mask, masked_image_latents

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device

        if block_state.height is None:
            block_state.height = components.default_height
        if block_state.width is None:
            block_state.width = components.default_width

        if block_state.padding_mask_crop is not None:
            block_state.crops_coords = components.mask_processor.get_crop_region(
                block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
            )
            block_state.resize_mode = "fill"
        else:
            block_state.crops_coords = None
            block_state.resize_mode = "default"

        block_state.image = components.image_processor.preprocess(
            block_state.image,
            height=block_state.height,
            width=block_state.width,
            crops_coords=block_state.crops_coords,
            resize_mode=block_state.resize_mode,
        )
        block_state.image = block_state.image.to(dtype=torch.float32)

        block_state.mask = components.mask_processor.preprocess(
            block_state.mask_image,
            height=block_state.height,
            width=block_state.width,
            resize_mode=block_state.resize_mode,
            crops_coords=block_state.crops_coords,
        )
        block_state.masked_image = block_state.image * (block_state.mask < 0.5)

        block_state.batch_size = block_state.image.shape[0]
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
        block_state.image_latents = self._encode_vae_image(
            components, image=block_state.image, generator=block_state.generator
        )

        # Prepare mask latent variables
        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
            components,
            block_state.mask,
            block_state.masked_image,
            block_state.batch_size,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
        )

        self.set_block_state(state, block_state)

        return components, state

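
# A shape check for the mask preparation above: the mask is resized straight down to
# latent resolution (1/8 of pixel space for the SDXL VAE) so it can be concatenated with
# the latents:
import torch as _torch
import torch.nn.functional as _F

_vae_scale_factor = 8
_mask = _torch.ones(1, 1, 1024, 1024)
_latent_mask = _F.interpolate(_mask, size=(1024 // _vae_scale_factor, 1024 // _vae_scale_factor))
assert _latent_mask.shape == (1, 1, 128, 128)
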
@@ -0,0 +1,380 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    StableDiffusionXLControlNetInputStep,
    StableDiffusionXLControlNetUnionInputStep,
    StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
    StableDiffusionXLImg2ImgPrepareLatentsStep,
    StableDiffusionXLImg2ImgSetTimestepsStep,
    StableDiffusionXLInpaintPrepareLatentsStep,
    StableDiffusionXLInputStep,
    StableDiffusionXLPrepareAdditionalConditioningStep,
    StableDiffusionXLPrepareLatentsStep,
    StableDiffusionXLSetTimestepsStep,
)
from .decoders import (
    StableDiffusionXLDecodeStep,
    StableDiffusionXLInpaintOverlayMaskStep,
)
from .denoise import (
    StableDiffusionXLControlNetDenoiseStep,
    StableDiffusionXLDenoiseStep,
    StableDiffusionXLInpaintControlNetDenoiseStep,
    StableDiffusionXLInpaintDenoiseStep,
)
from .encoders import (
    StableDiffusionXLInpaintVaeEncoderStep,
    StableDiffusionXLIPAdapterStep,
    StableDiffusionXLTextEncoderStep,
    StableDiffusionXLVaeEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# auto blocks & sequential blocks & mappings

# vae encoder (run before before_denoise)
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
    block_names = ["inpaint", "img2img"]
    block_trigger_inputs = ["mask_image", "image"]

    @property
    def description(self):
        return (
            "Vae encoder step that encodes the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n"
            + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
            + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
            + " - if neither `mask_image` nor `image` is provided, the step will be skipped."
        )

# optional ip-adapter (run before input step)
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLIPAdapterStep]
    block_names = ["ip_adapter"]
    block_trigger_inputs = ["ip_adapter_image"]

    @property
    def description(self):
        return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"

# before_denoise: text2img
|
||||
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLSetTimestepsStep,
|
||||
StableDiffusionXLPrepareLatentsStep,
|
||||
StableDiffusionXLPrepareAdditionalConditioningStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: img2img
|
||||
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLImg2ImgSetTimestepsStep,
|
||||
StableDiffusionXLImg2ImgPrepareLatentsStep,
|
||||
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: inpainting
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        StableDiffusionXLInputStep,
        StableDiffusionXLImg2ImgSetTimestepsStep,
        StableDiffusionXLInpaintPrepareLatentsStep,
        StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]

    @property
    def description(self):
        return (
            "Before-denoise step that prepares the inputs for the denoise step for the inpainting task.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
            + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
            + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
        )

# before_denoise: all tasks (text2img, img2img, inpainting)
class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        StableDiffusionXLInpaintBeforeDenoiseStep,
        StableDiffusionXLImg2ImgBeforeDenoiseStep,
        StableDiffusionXLBeforeDenoiseStep,
    ]
    block_names = ["inpaint", "img2img", "text2img"]
    block_trigger_inputs = ["mask", "image_latents", None]

    @property
    def description(self):
        return (
            "Before-denoise step that prepares the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks, as well as controlnet and controlnet_union.\n"
            + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n"
            + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
            + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided.\n"
        )

# optional controlnet input step (after before_denoise, before denoise)
# works for both controlnet and controlnet_union
class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
    block_names = ["controlnet_union", "controlnet"]
    block_trigger_inputs = ["control_mode", "control_image"]

    @property
    def description(self):
        return (
            "ControlNet input step that prepares the controlnet input.\n"
            + "This is an auto pipeline block that works for both controlnet and controlnet_union\n"
            + " (it should be called right before the denoise step).\n"
            + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when both `control_mode` and `control_image` are provided.\n"
            + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when only `control_image` is provided.\n"
            + " - If neither `control_mode` nor `control_image` is provided, the step is skipped."
        )

# denoise: controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep]
    block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
    block_trigger_inputs = ["mask", "controlnet_cond"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents with controlnet.\n"
            "This is an auto pipeline block that uses controlnet for text2img, img2img and inpainting tasks.\n"
            "This block should not be used without a `controlnet_cond` input.\n"
            " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when `mask` is provided.\n"
            " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when `mask` is not provided but `controlnet_cond` is.\n"
            " - If neither `mask` nor `controlnet_cond` is provided, the step is skipped."
        )

# denoise: all tasks with or without controlnet (text2img, img2img, inpainting)
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        StableDiffusionXLAutoControlNetDenoiseStep,
        StableDiffusionXLInpaintDenoiseStep,
        StableDiffusionXLDenoiseStep,
    ]
    block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
    block_trigger_inputs = ["controlnet_cond", "mask", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for text2img, img2img and inpainting tasks, with or without controlnet.\n"
            " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when `controlnet_cond` is provided (supports controlnet with text2img, img2img and inpainting tasks).\n"
            " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when `mask` is provided (supports inpainting tasks).\n"
            " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither `mask` nor `controlnet_cond` is provided (supports text2img and img2img tasks)."
        )

# decode: inpaint
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
    block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
    block_names = ["decode", "mask_overlay"]

    @property
    def description(self):
        return (
            "Inpaint decode step that decodes the denoised latents into image outputs.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n"
            + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
        )

# decode: all tasks (text2img, img2img, inpainting)
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
    block_names = ["inpaint", "non-inpaint"]
    block_trigger_inputs = ["padding_mask_crop", None]

    @property
    def description(self):
        return (
            "Decode step that decodes the denoised latents into image outputs.\n"
            + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n"
            + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n"
            + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
        )

# ip-adapter, controlnet, text2img, img2img, inpainting
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        StableDiffusionXLTextEncoderStep,
        StableDiffusionXLAutoIPAdapterStep,
        StableDiffusionXLAutoVaeEncoderStep,
        StableDiffusionXLAutoBeforeDenoiseStep,
        StableDiffusionXLAutoControlNetInputStep,
        StableDiffusionXLAutoDenoiseStep,
        StableDiffusionXLAutoDecodeStep,
    ]
    block_names = [
        "text_encoder",
        "ip_adapter",
        "image_encoder",
        "before_denoise",
        "controlnet_input",
        "denoise",
        "decoder",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
            + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
            + "- for inpainting, you need to provide `mask_image` and `image`, and can optionally provide `padding_mask_crop`\n"
            + "- to run the controlnet workflow, you need to provide `control_image`\n"
            + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
            + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
        )

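# Illustration only (hypothetical call shapes for a pipeline built from the
# blocks above; this is an experimental feature and the exact API may differ):
#   pipe(prompt="...")                                    -> text2img
#   pipe(prompt="...", image=init_image)                  -> img2img
#   pipe(prompt="...", image=init_image, mask_image=mask) -> inpainting
#   pipe(prompt="...", control_image=cond_image)          -> adds controlnet
#   pipe(prompt="...", ip_adapter_image=ref_image)        -> adds ip-adapter
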
# controlnet (input + denoise step)
class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
    block_classes = [
        StableDiffusionXLAutoControlNetInputStep,
        StableDiffusionXLAutoControlNetDenoiseStep,
    ]
    block_names = ["controlnet_input", "controlnet_denoise"]

    @property
    def description(self):
        return (
            "ControlNet auto step that prepares the controlnet input and denoises the latents. "
            + "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
            + " (It should replace the 'denoise' step.)"
        )

TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", StableDiffusionXLTextEncoderStep),
        ("input", StableDiffusionXLInputStep),
        ("set_timesteps", StableDiffusionXLSetTimestepsStep),
        ("prepare_latents", StableDiffusionXLPrepareLatentsStep),
        ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
        ("denoise", StableDiffusionXLDenoiseStep),
        ("decode", StableDiffusionXLDecodeStep),
    ]
)

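# Illustration only (hypothetical, assuming `InsertableDict` behaves like an
# ordered dict): a preset like TEXT2IMAGE_BLOCKS can be customised before a
# pipeline is assembled, e.g. swapping in the controlnet denoise step:
#
#   custom_blocks = TEXT2IMAGE_BLOCKS.copy()
#   custom_blocks["denoise"] = StableDiffusionXLAutoControlnetStep
#
# The remaining presets below follow the same pattern with task-specific steps.
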
IMAGE2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", StableDiffusionXLTextEncoderStep),
        ("image_encoder", StableDiffusionXLVaeEncoderStep),
        ("input", StableDiffusionXLInputStep),
        ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
        ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
        ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
        ("denoise", StableDiffusionXLDenoiseStep),
        ("decode", StableDiffusionXLDecodeStep),
    ]
)

INPAINT_BLOCKS = InsertableDict(
    [
        ("text_encoder", StableDiffusionXLTextEncoderStep),
        ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
        ("input", StableDiffusionXLInputStep),
        ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
        ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
        ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
        ("denoise", StableDiffusionXLInpaintDenoiseStep),
        ("decode", StableDiffusionXLInpaintDecodeStep),
    ]
)

CONTROLNET_BLOCKS = InsertableDict(
    [
        ("denoise", StableDiffusionXLAutoControlnetStep),
    ]
)

IP_ADAPTER_BLOCKS = InsertableDict(
    [
        ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
    ]
)

AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", StableDiffusionXLTextEncoderStep),
        ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
        ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
        ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
        ("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
        ("denoise", StableDiffusionXLAutoDenoiseStep),
        ("decode", StableDiffusionXLAutoDecodeStep),
    ]
)

ALL_BLOCKS = {
    "text2img": TEXT2IMAGE_BLOCKS,
    "img2img": IMAGE2IMAGE_BLOCKS,
    "inpaint": INPAINT_BLOCKS,
    "controlnet": CONTROLNET_BLOCKS,
    "ip_adapter": IP_ADAPTER_BLOCKS,
    "auto": AUTO_BLOCKS,
}

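# Usage sketch (hypothetical; this is an experimental feature and the loading
# API may change): `ALL_BLOCKS` maps a workflow name to its block preset, so a
# caller could assemble and run the auto workflow roughly like:
#
#   blocks = StableDiffusionXLAutoBlocks()          # or built from ALL_BLOCKS["auto"]
#   pipe = blocks.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")  # repo id is an example
#   image = pipe(prompt="an astronaut riding a horse").images[0]
#
# `init_pipeline` and the output attribute are assumptions for illustration,
# not a confirmed API.
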
@@ -0,0 +1,376 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ...image_processor import PipelineImageInput
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularPipeline
from ..modular_pipeline_utils import InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model-specific components:
## (1) it should inherit from ModularPipeline
## (2) acts like a container that holds components and configs
## (3) defines default configs (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherits from a model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use it together with ComponentsManager?
class StableDiffusionXLModularPipeline(
    ModularPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    ModularIPAdapterMixin,
):
    """
    A ModularPipeline for Stable Diffusion XL.

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        default_sample_size = 128
        if hasattr(self, "unet") and self.unet is not None:
            default_sample_size = self.unet.config.sample_size
        return default_sample_size

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_unet(self):
        num_channels_unet = 4
        if hasattr(self, "unet") and self.unet is not None:
            num_channels_unet = self.unet.config.in_channels
        return num_channels_unet

    @property
    def num_channels_latents(self):
        num_channels_latents = 4
        if hasattr(self, "vae") and self.vae is not None:
            num_channels_latents = self.vae.config.latent_channels
        return num_channels_latents

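# Example: with the stock SDXL UNet (sample_size=128) and VAE (four
# block_out_channels -> scale factor 2**(4-1) = 8), `default_height` and
# `default_width` above both resolve to 128 * 8 = 1024 pixels; the fallback
# values kick in only when no unet/vae component is loaded.
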
# YiYi/Sayak TODO: not used yet; maintain a list of schemas that can be used across all pipeline blocks
# auto_docstring
SDXL_INPUTS_SCHEMA = {
    "prompt": InputParam(
        "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
    ),
    "prompt_2": InputParam(
        "prompt_2",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
    ),
    "negative_prompt": InputParam(
        "negative_prompt",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts not to guide the image generation",
    ),
    "negative_prompt_2": InputParam(
        "negative_prompt_2",
        type_hint=Union[str, List[str]],
        description="The negative prompt or prompts for text_encoder_2",
    ),
    "cross_attention_kwargs": InputParam(
        "cross_attention_kwargs",
        type_hint=Optional[dict],
        description="Kwargs dictionary passed to the AttentionProcessor",
    ),
    "clip_skip": InputParam(
        "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
    ),
    "image": InputParam(
        "image",
        type_hint=PipelineImageInput,
        required=True,
        description="The image(s) to modify for img2img or inpainting",
    ),
    "mask_image": InputParam(
        "mask_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Mask image for inpainting; white pixels will be repainted",
    ),
    "generator": InputParam(
        "generator",
        type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
        description="Generator(s) for deterministic generation",
    ),
    "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
    "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
    "num_images_per_prompt": InputParam(
        "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
    ),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
    ),
    "timesteps": InputParam(
        "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
    ),
    "sigmas": InputParam(
        "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
    ),
    "denoising_end": InputParam(
        "denoising_end",
        type_hint=Optional[float],
        description="Fraction of denoising process to complete before termination",
    ),
    # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
    "strength": InputParam(
        "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
    ),
    "denoising_start": InputParam(
        "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
    ),
    "latents": InputParam(
        "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
    ),
    "padding_mask_crop": InputParam(
        "padding_mask_crop",
        type_hint=Optional[Tuple[int, int]],
        description="Size of margin in crop for image and mask",
    ),
    "original_size": InputParam(
        "original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Original size of the image for SDXL's micro-conditioning",
    ),
    "target_size": InputParam(
        "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
    ),
    "negative_original_size": InputParam(
        "negative_original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on image resolution",
    ),
    "negative_target_size": InputParam(
        "negative_target_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on target resolution",
    ),
    "crops_coords_top_left": InputParam(
        "crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Top-left coordinates for SDXL's micro-conditioning",
    ),
    "negative_crops_coords_top_left": InputParam(
        "negative_crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Negative conditioning crop coordinates",
    ),
    "aesthetic_score": InputParam(
        "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
    ),
    "negative_aesthetic_score": InputParam(
        "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
    ),
    "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
    "output_type": InputParam(
        "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
    ),
    "ip_adapter_image": InputParam(
        "ip_adapter_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Image(s) to be used as IP adapter",
    ),
    "control_image": InputParam(
        "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
    ),
    "control_guidance_start": InputParam(
        "control_guidance_start",
        type_hint=Union[float, List[float]],
        default=0.0,
        description="When ControlNet starts applying",
    ),
    "control_guidance_end": InputParam(
        "control_guidance_end",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="When ControlNet stops applying",
    ),
    "controlnet_conditioning_scale": InputParam(
        "controlnet_conditioning_scale",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="Scale factor for ControlNet outputs",
    ),
    "guess_mode": InputParam(
        "guess_mode",
        type_hint=bool,
        default=False,
        description="Enables ControlNet encoder to recognize input without prompts",
    ),
    "control_mode": InputParam(
        "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
    ),
}

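# Illustration only (the schema above is not wired up yet, per the TODO): the
# intent is for blocks to share definitions instead of redeclaring them, e.g.
#   inputs = [SDXL_INPUTS_SCHEMA["prompt"], SDXL_INPUTS_SCHEMA["num_inference_steps"]]
# (hypothetical usage sketch).
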
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
    "prompt_embeds": InputParam(
        "prompt_embeds",
        type_hint=torch.Tensor,
        required=True,
        description="Text embeddings used to guide image generation",
    ),
    "negative_prompt_embeds": InputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": InputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": InputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
    "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "preprocess_kwargs": InputParam(
        "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
    ),
    "latents": InputParam(
        "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
    ),
    "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
    ),
    "latent_timestep": InputParam(
        "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
    ),
    "image_latents": InputParam(
        "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
    ),
    "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
    "masked_image_latents": InputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "add_time_ids": InputParam(
        "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
    ),
    "negative_add_time_ids": InputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "ip_adapter_embeds": InputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": InputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": InputParam(
        "images",
        # np.ndarray (not np.array, which is a function) is the correct array type here
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
        required=True,
        description="Generated images",
    ),
}

SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
    "prompt_embeds": OutputParam(
        "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"
    ),
    "negative_prompt_embeds": OutputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": OutputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": OutputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
    "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "image_latents": OutputParam(
        "image_latents", type_hint=torch.Tensor, description="Latents representing reference image"
    ),
    "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
    "masked_image_latents": OutputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
    "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
    "latent_timestep": OutputParam(
        "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"
    ),
    "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
    "negative_add_time_ids": OutputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
    "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    "ip_adapter_embeds": OutputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": OutputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": OutputParam(
        "images",
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
        description="Generated images",
    ),
}

SDXL_OUTPUTS_SCHEMA = {
    "images": OutputParam(
        "images",
        type_hint=Union[
            Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]]], StableDiffusionXLPipelineOutput
        ],
        description="The final generated images",
    )
}
@@ -248,14 +248,15 @@ def _get_connected_pipeline(pipeline_cls):
        return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)


def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
    def get_model(pipeline_class_name):
        for task_mapping in SUPPORTED_TASKS_MAPPINGS:
            for model_name, pipeline in task_mapping.items():
                if pipeline.__name__ == pipeline_class_name:
                    return model_name
def _get_model(pipeline_class_name):
    for task_mapping in SUPPORTED_TASKS_MAPPINGS:
        for model_name, pipeline in task_mapping.items():
            if pipeline.__name__ == pipeline_class_name:
                return model_name

    model_name = get_model(pipeline_class_name)

def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
    model_name = _get_model(pipeline_class_name)

    if model_name is not None:
        task_class = mapping.get(model_name, None)

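# Example (illustrative): _get_model("StableDiffusionXLPipeline") scans
# SUPPORTED_TASKS_MAPPINGS for the pipeline class and returns its model key,
# which _get_task_class then resolves against the task-specific mapping.
# Hoisting the helper to module level (the change above) lets other callers
# reuse it without going through _get_task_class.
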
@@ -371,6 +371,22 @@ def maybe_raise_or_warn(
        )


# a simpler version of get_class_obj_and_candidates; it does not work with custom code
def simple_get_class_obj(library_name, class_name):
    from diffusers import pipelines

    is_pipeline_module = hasattr(pipelines, library_name)

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
    else:
        library = importlib.import_module(library_name)
        class_obj = getattr(library, class_name)

    return class_obj

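# Example (illustrative): simple_get_class_obj("stable_diffusion_xl",
# "StableDiffusionXLPipeline") resolves through `diffusers.pipelines`, while a
# plain library name such as "transformers" falls back to importlib.import_module;
# unlike get_class_obj_and_candidates below, it cannot load custom remote code.
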
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
@@ -452,7 +468,7 @@ def _get_pipeline_class(
            revision=revision,
        )

    if class_obj.__name__ != "DiffusionPipeline":
    if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
        return class_obj

    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
@@ -892,7 +908,10 @@ def _fetch_class_library_tuple(module):
    library = not_compiled_module.__module__

    # retrieve class_name
    class_name = not_compiled_module.__class__.__name__
    if isinstance(not_compiled_module, type):
        class_name = not_compiled_module.__name__
    else:
        class_name = not_compiled_module.__class__.__name__

    return (library, class_name)

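# Example (illustrative): with the isinstance(..., type) branch above, passing a
# class object (e.g. a model class rather than a model instance) now yields its
# own __name__ instead of the metaclass name "type"; instances keep the old
# not_compiled_module.__class__.__name__ behaviour.
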
@@ -1986,11 +1986,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
        }

        optional_components = (
            pipeline._optional_components
            if hasattr(pipeline, "_optional_components") and pipeline._optional_components
            else []
        )
        missing_modules = (
            set(expected_modules)
            - set(pipeline._optional_components)
            - set(pipeline_kwargs.keys())
            - set(true_optional_modules)
            set(expected_modules) - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules)
        )

        if len(missing_modules) > 0:

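# Example (illustrative): pipelines that define `_optional_components = None`
# (or omit it) no longer break the set arithmetic above — `optional_components`
# now falls back to an empty list before `missing_modules` is computed.
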
@@ -2,6 +2,126 @@
from ..utils import DummyObject, requires_backends


class AdaptiveProjectedGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class AutoGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ClassifierFreeGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PerturbedAttentionGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SkipLayerGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SmoothedEnergyGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class TangentialClassifierFreeGuidance(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

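# Example (illustrative): these dummies mean `from diffusers import
# ClassifierFreeGuidance` still succeeds without torch installed, but
# instantiating it raises an import error from requires_backends telling the
# user to install the missing backend.
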
class FasterCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

@@ -47,6 +167,21 @@ class HookRegistry(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class LayerSkipConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
    _backends = ["torch"]

@@ -62,6 +197,21 @@ class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class SmoothedEnergyGuidanceConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def apply_faster_cache(*args, **kwargs):
    requires_backends(apply_faster_cache, ["torch"])


@@ -70,6 +220,10 @@ def apply_first_block_cache(*args, **kwargs):
    requires_backends(apply_first_block_cache, ["torch"])


def apply_layer_skip(*args, **kwargs):
    requires_backends(apply_layer_skip, ["torch"])


def apply_pyramid_attention_broadcast(*args, **kwargs):
    requires_backends(apply_pyramid_attention_broadcast, ["torch"])

@@ -1199,6 +1353,66 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class ComponentsManager(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ComponentSpec(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ModularPipeline(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ModularPipelineBlocks(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def get_constant_schedule(*args, **kwargs):
    requires_backends(get_constant_schedule, ["torch"])

@@ -2,6 +2,36 @@
from ..utils import DummyObject, requires_backends


class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AllegroPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -20,8 +20,11 @@ import json
import os
import re
import shutil
import signal
import sys
import threading
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request

@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
_HF_REMOTE_CODE_LOCK = threading.Lock()


def get_diffusers_versions():
@@ -154,33 +159,87 @@ def check_imports(filename):
    return get_relative_imports(filename)


def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
def _raise_timeout_error(signum, frame):
    raise ValueError(
        "Loading this model requires you to execute custom code contained in the model repository on your local "
        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
    )


def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
    if trust_remote_code is None:
        if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            prev_sig_handler = None
            try:
                prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # OS which does not support signal.SIGALRM
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
            finally:
                if prev_sig_handler is not None:
                    signal.signal(signal.SIGALRM, prev_sig_handler)
                    signal.alarm(0)
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code

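# Example (illustrative), based on the code above:
#   resolve_trust_remote_code(True, "user/repo", has_remote_code=True)   -> True (no prompt)
#   resolve_trust_remote_code(None, "user/repo", has_remote_code=True)   -> interactive [y/N]
#       prompt, aborted via SIGALRM after TIME_OUT_REMOTE_CODE seconds
#   resolve_trust_remote_code(False, "user/repo", has_remote_code=True)  -> raises ValueError
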
def get_class_in_module(class_name, module_path, force_reload=False):
    """
    Import a module on the cache directory for modules and extract a class from it.
    """
    module_path = module_path.replace(os.path.sep, ".")
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
        # separator. We do a bit of monkey patching to detect and fix this case.
        if not (
            pretrained_model_name_or_path is not None
            and "." in pretrained_model_name_or_path
            and module_path.startswith("diffusers_modules")
            and pretrained_model_name_or_path.replace("/", "--") in module_path
        ):
            raise e  # We can't figure this one out, just reraise the original error
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path

        corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
        corrected_path = corrected_path.replace(
            pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
            pretrained_model_name_or_path.replace("/", "--"),
        )
        module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins
            sys.modules[name] = module
        else:
            module = cached_module

        module_spec.loader.exec_module(module)

    if class_name is None:
        return find_pipeline_class(module)

    return getattr(module, class_name)

@@ -472,4 +531,4 @@ def get_class_from_dynamic_module(
        revision=revision,
        local_files_only=local_files_only,
    )
    return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
    return get_class_in_module(class_name, final_module)

@@ -467,6 +467,7 @@ class PushToHubMixin:
        token: Optional[str] = None,
        commit_message: Optional[str] = None,
        create_pr: bool = False,
        subfolder: Optional[str] = None,
    ):
        """
        Uploads all files in `working_dir` to `repo_id`.
@@ -481,7 +482,12 @@ class PushToHubMixin:

        logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
        return upload_folder(
            repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
            repo_id=repo_id,
            folder_path=working_dir,
            token=token,
            commit_message=commit_message,
            create_pr=create_pr,
            path_in_repo=subfolder,
        )

    def push_to_hub(
@@ -493,6 +499,7 @@ class PushToHubMixin:
        create_pr: bool = False,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        subfolder: Optional[str] = None,
    ) -> str:
        """
        Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
@@ -534,8 +541,9 @@ class PushToHubMixin:
        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

        # Create a new empty model card and eventually tag it
        model_card = load_or_create_model_card(repo_id, token=token)
        model_card = populate_model_card(model_card)
        if not subfolder:
            model_card = load_or_create_model_card(repo_id, token=token)
            model_card = populate_model_card(model_card)

        # Save all files.
        save_kwargs = {"safe_serialization": safe_serialization}
@@ -546,7 +554,8 @@ class PushToHubMixin:
            self.save_pretrained(tmpdir, **save_kwargs)

            # Update model card if needed:
            model_card.save(os.path.join(tmpdir, "README.md"))
            if not subfolder:
                model_card.save(os.path.join(tmpdir, "README.md"))

            return self._upload_folder(
                tmpdir,
@@ -554,4 +563,5 @@ class PushToHubMixin:
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
                subfolder=subfolder,
            )

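# Example (illustrative): with the new `subfolder` argument, a call such as
#   component.push_to_hub("user/modular-repo", subfolder="unet")
# uploads the files under `unet/` in the repo (via upload_folder's
# path_in_repo) and skips README/model-card generation for subfolder pushes.
# `component` and the repo id are placeholders.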