mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
yiyixuxu
2025-06-25 11:26:36 +02:00
parent cb328d3ff9
commit 7d2a633e02
28 changed files with 828 additions and 1452 deletions

View File

@@ -794,8 +794,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
@@ -875,6 +875,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanTransformer3DModel,
WanVACETransformer3DModel,
)
from .modular_pipelines import (
ComponentsManager,
ComponentSpec,
ModularLoader,
ModularPipeline,
ModularPipelineBlocks,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -907,13 +914,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .modular_pipelines import (
ModularLoader,
ModularPipeline,
ModularPipelineBlocks,
ComponentSpec,
ComponentsManager,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
@@ -978,6 +978,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
@@ -1182,10 +1186,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):

View File

@@ -18,10 +18,11 @@ Usage example:
"""
import ast
from argparse import ArgumentParser, Namespace
from pathlib import Path
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path
from ..utils import logging
from . import BaseDiffusersCLICommand
@@ -57,7 +58,7 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# determine the block to be saved.
out = self._get_class_names(self.block_module_name)
classes_found = list({cls for cls, _ in out})
if self.block_class_name is not None:
child_class, parent_class = self._choose_block(out, self.block_class_name)
if child_class is None and parent_class is None:
@@ -125,9 +126,9 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
val = self._get_base_name(node.value)
return f"{val}.{node.attr}" if val else node.attr
return None
def _create_automap(self, parent_class, child_class):
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
return {"auto_map": auto_map}

View File

@@ -15,9 +15,9 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
from .custom_blocks import CustomBlocksCommand
def main():

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -74,10 +75,10 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.momentum_buffer = None
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
@@ -123,19 +124,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -160,25 +161,25 @@ def normalized_guidance(
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
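
The core of `normalized_guidance` above is a projection of the (momentum-smoothed, norm-clipped) guidance difference into components parallel and orthogonal to the conditional prediction, with `eta` downweighting the parallel part. A self-contained sketch of that decomposition, using hypothetical tensors and settings:

import torch

pred_cond = torch.randn(2, 4)     # hypothetical conditional prediction
pred_uncond = torch.randn(2, 4)   # hypothetical unconditional prediction
eta, guidance_scale = 1.0, 7.5

diff = pred_cond - pred_uncond
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=-1)
v0 = diff.double()
v0_parallel = (v0 * v1).sum(dim=-1, keepdim=True) * v1   # component along pred_cond
v0_orthogonal = v0 - v0_parallel
update = v0_orthogonal.type_as(diff) + eta * v0_parallel.type_as(diff)
pred = pred_uncond + guidance_scale * update             # APG-style combination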

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -113,18 +114,18 @@ class AutoGuidance(BaseGuidance):
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
@@ -144,9 +145,9 @@ class AutoGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@@ -161,17 +162,17 @@ class AutoGuidance(BaseGuidance):
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -74,12 +75,12 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
@@ -116,17 +117,17 @@ class ClassifierFreeGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
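
The `_is_cfg_enabled` check above (shared, with small variations, by the other guiders in this commit) gates guidance to a fractional window of the schedule and disables it when the scale makes guidance a no-op. A standalone sketch with hypothetical settings:

import math

num_inference_steps = 50
start, stop = 0.0, 0.8              # fractions of the schedule
guidance_scale = 7.5
use_original_formulation = False

skip_start_step = int(start * num_inference_steps)   # 0
skip_stop_step = int(stop * num_inference_steps)     # 40
no_op_value = 0.0 if use_original_formulation else 1.0
is_close = math.isclose(guidance_scale, no_op_value)

enabled = [s for s in range(num_inference_steps)
           if skip_start_step <= s < skip_stop_step and not is_close]
print(min(enabled), max(enabled))   # 0 39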

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -72,12 +73,12 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
@@ -106,7 +107,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@@ -121,19 +122,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close

View File

@@ -58,10 +58,10 @@ class BaseGuidance:
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
@@ -104,14 +104,14 @@ class BaseGuidance:
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
@@ -119,7 +119,7 @@ class BaseGuidance:
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
@@ -139,15 +139,15 @@ class BaseGuidance:
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
"""

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -148,19 +149,19 @@ class SkipLayerGuidance(BaseGuidance):
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -204,7 +205,7 @@ class SkipLayerGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@@ -221,31 +222,31 @@ class SkipLayerGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -141,19 +142,19 @@ class SmoothedEnergyGuidance(BaseGuidance):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -197,7 +198,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@@ -214,31 +215,31 @@ class SmoothedEnergyGuidance(BaseGuidance):
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
@@ -63,10 +64,10 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
@@ -101,24 +102,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
cond_dtype = pred_cond.dtype
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
@@ -129,9 +130,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
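
In the TCFG `normalized_guidance` above, the stacked predictions are factorized with an SVD and the unconditional prediction is projected back through the right-singular vectors (the commit modifies `Vh` before projecting; that step falls outside this hunk). A sketch of the projection round trip with the unmodified `Vh` and hypothetical shapes:

import torch

pred_cond = torch.randn(2, 6)     # hypothetical, already flattened
pred_uncond = torch.randn(2, 6)

preds = torch.stack([pred_cond, pred_uncond], dim=1).float()   # (2, 2, 6)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)        # Vh: (2, 2, 6)

uncond_flat = pred_uncond.unsqueeze(1).float()                 # (2, 1, 6)
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))         # coordinates in the Vh basis
x_Vh_V = torch.matmul(x_Vh, Vh)                                # back to feature space
projected = x_Vh_V.reshape(pred_uncond.shape)

# With the unmodified Vh this round trip is (numerically) the identity, since
# pred_uncond lies in the row space of the stacked predictions:
assert torch.allclose(projected, pred_uncond, atol=1e-4)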

View File

@@ -20,7 +20,12 @@ import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._common import (
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
_ATTENTION_CLASSES,
_FEEDFORWARD_CLASSES,
_get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
@@ -198,15 +203,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
@@ -215,7 +220,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook):
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
@@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
@@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
@@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
in the future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
@@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
@@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
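
`_gaussian_blur_2d` above reshapes the leading square block of tokens into a 2D grid and blurs it (with `sigma > sigma_threshold_inf` treated as an infinite blur, i.e. a global average). A minimal separable-blur sketch under an assumed kernel construction, not the diffusers kernel code:

import torch
import torch.nn.functional as F

def gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    half = (kernel_size - 1) / 2
    x = torch.linspace(-half, half, kernel_size)
    kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()

batch_size, embed_dim, side = 1, 8, 16          # 16*16 = 256 square tokens
grid = torch.randn(batch_size, embed_dim, side, side)

k = gaussian_kernel1d(kernel_size=9, sigma=2.0)
kernel2d = torch.outer(k, k)[None, None].repeat(embed_dim, 1, 1, 1)
blurred = F.conv2d(grid, kernel2d, padding=4, groups=embed_dim)  # depthwise blur
assert blurred.shape == grid.shape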

View File

@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,

View File

@@ -49,13 +49,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
from .components_manager import ComponentsManager
else:
import sys

View File

@@ -12,24 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
import uuid
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy
from typing import Any, Dict, List, Optional, Union
import torch
import time
from dataclasses import dataclass
from ..utils import (
is_accelerate_available,
logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
import uuid
if is_accelerate_available():
@@ -237,12 +232,12 @@ class AutoOffloadStrategy:
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
self.added_time = OrderedDict() # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
"""
Lookup component_ids by name, collection, or load_id.
@@ -251,7 +246,7 @@ class ComponentsManager:
components = self.components
if name:
ids_by_name = set()
ids_by_name = set()
for component_id, component in components.items():
comp_name = self._id_to_name(component_id)
if comp_name == name:
@@ -272,16 +267,16 @@ class ComponentsManager:
ids_by_load_id.add(name)
else:
ids_by_load_id = set(components.keys())
ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
return ids
@staticmethod
def _id_to_name(component_id: str):
return "_".join(component_id.split("_")[:-1])
def add(self, name, component, collection: Optional[str] = None):
component_id = f"{name}_{uuid.uuid4()}"
# check for duplicated components
@@ -305,7 +300,7 @@ class ComponentsManager:
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
if components_with_same_load_id:
existing = ", ".join(components_with_same_load_id)
logger.warning(
@@ -320,7 +315,7 @@ class ComponentsManager:
if collection:
if collection not in self.collections:
self.collections[collection] = set()
if not component_id in self.collections[collection]:
if component_id not in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
@@ -331,8 +326,8 @@ class ComponentsManager:
logger.info(f"Added component '{name}' as '{component_id}'")
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
self.enable_auto_cpu_offload(self._auto_offload_device)
return component_id
@@ -341,14 +336,14 @@ class ComponentsManager:
if component_id not in self.components:
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
return
component = self.components.pop(component_id)
self.added_time.pop(component_id)
for collection in self.collections:
if component_id in self.collections[collection]:
self.collections[collection].remove(component_id)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
else:
@@ -386,7 +381,7 @@ class ComponentsManager:
Dictionary mapping component IDs to components,
or list of (base_name, component) tuples if as_name_component_tuples=True
"""
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
@@ -397,16 +392,16 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return component_id
if names is None:
if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
else:
return components
# Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
def matches_pattern(component_id, pattern, exact_match=False):
"""
Helper function to check if a component matches a pattern based on its base name.
@@ -417,124 +412,124 @@ class ComponentsManager:
exact_match: If True, only exact matches to base_name are considered
"""
base_name = base_names[component_id]
# Exact match with base name
if exact_match:
return pattern == base_name
# Prefix match (ends with *)
elif pattern.endswith('*'):
prefix = pattern[:-1]
return base_name.startswith(prefix)
# Contains match (starts with *)
elif pattern.startswith('*'):
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
return search in base_name
# Exact match (no wildcards)
else:
return pattern == base_name
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith('!')
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if '|' in names:
terms = names.split('|')
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp for comp_id, comp in components.items()
comp_id: comp for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *)
elif names.endswith('*'):
prefix = names[:-1]
matches = {
comp_id: comp for comp_id, comp in components.items()
comp_id: comp for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *)
elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:]
matches = {
comp_id: comp for comp_id, comp in components.items()
comp_id: comp for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp for comp_id, comp in components.items()
comp_id: comp for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else:
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
else:
return matches
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result)
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
else:
return results
else:
raise ValueError(f"Invalid type for names: {type(names)}")
@@ -595,14 +590,14 @@ class ComponentsManager:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[component_id]
# Build complete info dict first
info = {
"model_id": component_id,
"added_time": self.added_time[component_id],
"collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
}
# Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module):
# Check for hook information
@@ -610,7 +605,7 @@ class ComponentsManager:
execution_device = None
if has_hook and hasattr(component._hf_hook, "execution_device"):
execution_device = component._hf_hook.execution_device
info.update({
"class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3),
@@ -631,8 +626,8 @@ class ComponentsManager:
if any("IPAdapter" in ptype for ptype in processor_types):
# Then get scales only from IP-Adapter processors
scales = {
k: v.scale
for k, v in processors.items()
k: v.scale
for k, v in processors.items()
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
}
if scales:
@@ -646,7 +641,7 @@ class ComponentsManager:
else:
# List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields}
return info
def __repr__(self):
@@ -659,13 +654,13 @@ class ComponentsManager:
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return name
# Extract load_id if available
def get_load_id(component):
if hasattr(component, "_diffusers_load_id"):
return component._diffusers_load_id
return "N/A"
# Format device info compactly
def format_device(component, info):
if not info["has_hook"]:
@@ -674,18 +669,18 @@ class ComponentsManager:
device = str(getattr(component, 'device', 'N/A'))
exec_device = str(info['execution_device'] or 'N/A')
return f"{device}({exec_device})"
# Get all simple names to calculate width
simple_names = [get_simple_name(id) for id in self.components.keys()]
# Get max length of load_ids for models
load_ids = [
get_load_id(component)
for component in self.components.values()
get_load_id(component)
for component in self.components.values()
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Get all collections for each component
component_collections = {}
for name in self.components.keys():
@@ -695,11 +690,11 @@ class ComponentsManager:
component_collections[name].append(coll)
if not component_collections[name]:
component_collections[name] = ["N/A"]
# Find the maximum collection name length
all_collections = [coll for colls in component_collections.values() for coll in colls]
max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
col_widths = {
"name": max(15, max(len(name) for name in simple_names)),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
@@ -736,21 +731,21 @@ class ComponentsManager:
device_str = format_device(component, info)
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
load_id = get_load_id(component)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
output += dash_line
# Other components section
@@ -766,17 +761,17 @@ class ComponentsManager:
for name, component in others.items():
info = self.get_model_info(name)
simple_name = get_simple_name(name)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"
output += dash_line
# Add additional component info
@@ -789,8 +784,8 @@ class ComponentsManager:
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n"
output += " IP-Adapter: Enabled\n"
return output
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
@@ -821,13 +816,13 @@ class ComponentsManager:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
@@ -860,15 +855,15 @@ class ComponentsManager:
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
return self.components[component_id]
results = self.get(name, collection, load_id)
if not results:
raise ValueError(f"No components found matching '{name}'")
if len(results) > 1:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -894,17 +889,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
if value_tuple not in value_to_keys:
value_to_keys[value_tuple] = []
value_to_keys[value_tuple].append(key)
def find_common_prefix(keys: List[str]) -> str:
"""Find the shortest common prefix among a list of dot-separated keys."""
if not keys:
return ""
if len(keys) == 1:
return keys[0]
# Split all keys into parts
key_parts = [k.split('.') for k in keys]
# Find how many initial parts are common
common_length = 0
for parts in zip(*key_parts):
@@ -912,10 +907,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
common_length += 1
else:
break
if common_length == 0:
return ""
# Return the common prefix
return '.'.join(key_parts[0][:common_length])
@@ -929,5 +924,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
summary[prefix] = value
else:
summary[""] = value # Use empty string if no common prefix
return summary
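
`summarize_dict_by_value_and_parts` groups keys that share a value under their longest common dotted prefix. A standalone reimplementation of the prefix helper for illustration (the equality test elided from the hunk is assumed):

def find_common_prefix(keys):
    if not keys:
        return ""
    if len(keys) == 1:
        return keys[0]
    key_parts = [k.split('.') for k in keys]
    common_length = 0
    for parts in zip(*key_parts):
        if len(set(parts)) == 1:     # assumed equality check
            common_length += 1
        else:
            break
    return '.'.join(key_parts[0][:common_length])

print(find_common_prefix(["unet.down_blocks.0", "unet.down_blocks.1"]))  # unet.down_blocks
print(find_common_prefix(["unet.conv_in", "vae.conv_in"]))               # "" (no common prefix)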

File diff suppressed because it is too large

View File

@@ -12,44 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available
if is_torch_available():
import torch
pass
class InsertableOrderedDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
def __repr__(self):
if not self:
return "InsertableOrderedDict()"
items = []
for i, (key, value) in enumerate(self.items()):
items.append(f"{i}: ({repr(key)}, {repr(value)})")
return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])"
@@ -85,24 +86,24 @@ class ComponentSpec:
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (self.name == other.name and
self.load_id == other.load_id and
return (self.name == other.name and
self.load_id == other.load_id and
self.default_creation_method == other.default_creation_method)
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` or `load` method")
# throw an error if component is created with `create` method but not a subclass of ConfigMixin
@@ -113,19 +114,19 @@ class ComponentSpec:
"created with `ComponentSpec.load` method"
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
)
type_hint = component.__class__
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
@classmethod
def loading_fields(cls) -> List[str]:
"""
@@ -133,8 +134,8 @@ class ComponentSpec:
(i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
@@ -144,7 +145,7 @@ class ComponentSpec:
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
@@ -165,26 +166,26 @@ class ComponentSpec:
If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not created with `load` method).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
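
The encode/decode pair above round-trips: the loading fields are joined with `|`, with `None` written as the literal "null". A sketch assuming the loading fields are repo, subfolder, variant, and revision (hypothetical values):

loading_fields = ["repo", "subfolder", "variant", "revision"]   # assumed field order
values = {"repo": "some-org/some-repo", "subfolder": "unet", "variant": None, "revision": None}

# Encode (mirrors the load_id property):
parts = ["null" if values[f] is None else values[f] for f in loading_fields]
load_id = "|".join(p for p in parts if p)
print(load_id)              # some-org/some-repo|unet|null|null

# Decode (mirrors decode_load_id):
decoded = {f: None for f in loading_fields}
for i, part in enumerate(load_id.split("|")):
    if i < len(loading_fields):
        decoded[loading_fields[i]] = None if part == "null" else part
assert decoded == values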
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
@@ -194,11 +195,11 @@ class ComponentSpec:
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError(
f"`type_hint` is required when using from_config creation method."
"`type_hint` is required when using from_config creation method."
)
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
@@ -211,17 +212,17 @@ class ComponentSpec:
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
@@ -229,8 +230,8 @@ class ComponentSpec:
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
raise ValueError("`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None:
try:
from diffusers import AutoModel
@@ -244,17 +245,17 @@ class ComponentSpec:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@dataclass
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
@@ -281,7 +282,7 @@ class InputParam:
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
@@ -315,14 +316,14 @@ def format_inputs_short(inputs):
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
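
`format_inputs_short` renders required inputs bare and optional ones as `name=default`; a sketch with a minimal stand-in for the param objects (hypothetical names and defaults):

from dataclasses import dataclass
from typing import Any

@dataclass
class Param:                  # stand-in for InputParam
    name: str
    required: bool = False
    default: Any = None

inputs = [Param("prompt", required=True), Param("height", default=1024), Param("width", default=1024)]
required_str = ", ".join(p.name for p in inputs if p.required)
optional_str = ", ".join(f"{p.name}={p.default}" for p in inputs if not p.required)
inputs_str = f"{required_str}, {optional_str}" if required_str and optional_str else required_str or optional_str
print(inputs_str)             # prompt, height=1024, width=1024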
@@ -353,18 +354,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = []
new_output_parts = []
for out in intermediates_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
@@ -372,7 +373,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
@@ -390,18 +391,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
@@ -411,7 +412,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
@@ -419,22 +420,22 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
@@ -442,7 +443,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(
@@ -452,9 +453,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
@@ -500,42 +501,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
@@ -553,27 +554,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
@@ -618,9 +619,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
return output

View File

@@ -1,16 +1,19 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os
from typing import Union, List, Optional, Tuple
import torch
import PIL
import numpy as np
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
@@ -189,8 +192,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
if group_key in name:
return group_name
return None
class ModularNode(ConfigMixin):
config_name = "node_config.json"
@@ -214,15 +217,15 @@ class ModularNode(ConfigMixin):
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in the node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
@@ -236,10 +239,10 @@ class ModularNode(ConfigMixin):
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
@@ -248,7 +251,7 @@ class ModularNode(ConfigMixin):
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
@@ -285,7 +288,7 @@ class ModularNode(ConfigMixin):
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
@@ -303,7 +306,7 @@ class ModularNode(ConfigMixin):
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
@@ -326,10 +329,10 @@ class ModularNode(ConfigMixin):
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
@@ -339,7 +342,7 @@ class ModularNode(ConfigMixin):
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@@ -347,7 +350,7 @@ class ModularNode(ConfigMixin):
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
@@ -368,13 +371,13 @@ class ModularNode(ConfigMixin):
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
@@ -388,13 +391,13 @@ class ModularNode(ConfigMixin):
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
@@ -408,7 +411,7 @@ class ModularNode(ConfigMixin):
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
@@ -427,22 +430,22 @@ class ModularNode(ConfigMixin):
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}
# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
@@ -457,16 +460,16 @@ class ModularNode(ConfigMixin):
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
@@ -483,7 +486,7 @@ class ModularNode(ConfigMixin):
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
@@ -495,14 +498,14 @@ class ModularNode(ConfigMixin):
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output
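Putting the node API together, a hedged end-to-end sketch; the save/load method names are inferred from the log messages above and may differ, everything else follows the signatures shown in this diff.

# Hedged usage sketch for a ModularNode built as in the comments above.
from diffusers import ComponentsManager

components = ComponentsManager()
node.setup(components)  # wires the blocks' loader to the component manager (collection defaults to None)

node.save_mellon_config("node/mellon_config.json")  # method name inferred from the "saved to" log message
config = ModularNode.load_mellon_config("node/mellon_config.json")  # returns {"module": ..., "name_mapping": ...}

output = node.execute(text_input="a photo of a cat")  # Mellon-side names are remapped via name_mapping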

View File

@@ -34,11 +34,24 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
from .encoders import (
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLTextEncoderStep,
)
from .modular_block_mappings import (
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
CONTROLNET_UNION_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
SDXL_SUPPORTED_BLOCKS,
TEXT2IMAGE_BLOCKS,
)
from .modular_loader import StableDiffusionXLModularLoader
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
else:
import sys

View File

@@ -13,32 +13,27 @@
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Optional, Tuple, Union
import PIL
import torch
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module
from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...schedulers import EulerDiscreteScheduler
from ...configuration_utils import FrozenDict
from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import (
AutoPipelineBlocks,
ModularLoader,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_loader import StableDiffusionXLModularLoader
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -237,7 +232,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."),
InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."),
]
@property
def intermediates_outputs(self) -> List[str]:
return [
@@ -250,7 +245,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"),
OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"),
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
@@ -270,13 +265,13 @@ class StableDiffusionXLInputStep(PipelineBlock):
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list):
raise ValueError("`ip_adapter_embeds` must be a list")
if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list):
raise ValueError("`negative_ip_adapter_embeds` must be a list")
if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None:
for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape:
@@ -298,19 +293,19 @@ class StableDiffusionXLInputStep(PipelineBlock):
# duplicate text embeddings for each generation per prompt, using mps friendly method
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1)
block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1)
if block_state.negative_pooled_prompt_embeds is not None:
block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1)
if block_state.ip_adapter_embeds is not None:
for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0)
@@ -318,7 +313,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
if block_state.negative_ip_adapter_embeds is not None:
for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds):
block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0)
self.add_block_state(state, block_state)
return components, state
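The repeat/view pattern above is easier to verify with concrete shapes; a self-contained sketch of the same tensor math:

import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 77, 2048, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# repeat along dim=1, then fold the copies back into the batch dimension,
# matching the mps-friendly duplication used in the block above
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)  # (2, 231, 2048)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
assert prompt_embeds.shape == (6, 77, 2048)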
@@ -356,14 +351,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"),
]
@property
def intermediates_outputs(self) -> List[str]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation")
]
@@ -455,7 +450,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
@@ -473,7 +468,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")]
@@ -524,7 +519,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
InputParam(
"strength",
"strength",
default=0.9999,
description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` "
"will be used as a starting point, adding more noise to it the larger the `strength`. The number of "
@@ -540,46 +535,46 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."
),
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."
),
),
InputParam(
"mask",
required=True,
type_hint=torch.Tensor,
"mask",
required=True,
type_hint=torch.Tensor,
description="The mask for the inpainting generation. Can be generated in vae_encode step."
),
),
InputParam(
"masked_image_latents",
type_hint=torch.Tensor,
"masked_image_latents",
type_hint=torch.Tensor,
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step."
),
InputParam(
"dtype",
type_hint=torch.dtype,
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs"
)
]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"),
return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"),
OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")]
@@ -587,13 +582,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
# YiYi TODO: update the _encode_vae_image so that we can use #Copied from
@staticmethod
def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
@@ -619,7 +614,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
return image_latents
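The scaling above is one half of the usual diffusers latent normalization; a self-contained sketch of both branches (the mean/std formula is an assumption consistent with the else-branch shown here):

import torch

def normalize_latents(latents, scaling_factor, latents_mean=None, latents_std=None):
    # standardize when the VAE config carries latents_mean/latents_std,
    # otherwise apply plain scaling as in the else-branch above
    if latents_mean is not None and latents_std is not None:
        return (latents - latents_mean) * scaling_factor / latents_std
    return scaling_factor * latents

image_latents = torch.randn(1, 4, 128, 128)
z = normalize_latents(image_latents, scaling_factor=0.13025)  # 0.13025 is SDXL's VAE scaling factor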
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
def prepare_latents_inpaint(
@@ -737,15 +732,15 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
block_state.is_strength_max = block_state.strength == 1.0
# for non-inpainting specific unet, we do not need masked_image_latents
@@ -822,9 +817,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."),
InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."),
InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]
@property
@@ -886,14 +881,14 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
),
InputParam(
"dtype",
type_hint=torch.dtype,
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs"
)
]
@@ -902,8 +897,8 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents",
type_hint=torch.Tensor,
"latents",
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process"
)
]
@@ -980,7 +975,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("requires_aesthetics_score", False),]
@@ -1008,15 +1003,15 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components
@@ -1183,29 +1178,29 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")]
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components
@@ -1344,26 +1339,26 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
),
]
@@ -1395,12 +1390,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
device,
dtype,
crops_coords=None,
):
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
@@ -1416,9 +1411,9 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
@@ -1446,14 +1441,14 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
# (1.3)
# global_pool_conditions
# global_pool_conditions
block_state.global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
# (1.4)
# guess_mode
# guess_mode
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
# (1.5)
@@ -1501,12 +1496,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
]
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
block_state.controlnet_cond = block_state.control_image
block_state.conditioning_scale = block_state.controlnet_conditioning_scale
self.add_block_state(state, block_state)
return components, state
@@ -1542,32 +1537,32 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of model tensor inputs. Can be generated in input step."
),
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
),
]
@@ -1599,12 +1594,12 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
device,
dtype,
crops_coords=None,
):
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
@@ -1618,7 +1613,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
controlnet = unwrap_module(components.controlnet)
@@ -1651,7 +1646,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
if len(block_state.control_image) != len(block_state.control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")
# control_type
# control_type
block_state.num_control_type = controlnet.config.num_control_type
block_state.control_type = [0 for _ in range(block_state.num_control_type)]
for control_idx in block_state.control_mode:
@@ -1676,7 +1671,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
crops_coords=block_state.crops_coords,
)
block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
@@ -1687,7 +1682,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
block_state.control_type_idx = block_state.control_mode
block_state.controlnet_cond = block_state.control_image
block_state.conditioning_scale = block_state.controlnet_conditioning_scale
self.add_block_state(state, block_state)
return components, state
@@ -1698,7 +1693,7 @@ class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
block_names = ["controlnet_union", "controlnet"]
block_trigger_inputs = ["control_mode", "control_image"]
@property
def description(self):
return "Controlnet Input step that prepare the controlnet input.\n" + \

View File

@@ -12,29 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
import numpy as np
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...configuration_utils import FrozenDict
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline import (
AutoPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -44,15 +40,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@@ -160,10 +156,10 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [

View File

@@ -12,17 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple
import PIL
import torch
from collections import OrderedDict
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
@@ -30,26 +35,10 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor, unwrap_module
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...configuration_utils import FrozenDict
from transformers import (
CLIPTextModel,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...schedulers import EulerDiscreteScheduler
from ...guiders import ClassifierFreeGuidance
from ..modular_pipeline import AutoPipelineBlocks, PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -71,7 +60,7 @@ def retrieve_latents(
class StableDiffusionXLIPAdapterStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
@@ -79,7 +68,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
@@ -87,8 +76,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@@ -97,8 +86,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter"
)
@@ -111,7 +100,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
@staticmethod
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
@@ -137,7 +126,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
@@ -219,7 +208,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
return (
"Text Encoder step that generates text_embeddings to guide the image generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
@@ -228,9 +217,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@@ -546,7 +535,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
@@ -558,9 +547,9 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@@ -576,7 +565,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
@property
@@ -586,13 +575,13 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Copied from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
@@ -618,8 +607,8 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
return image_latents
@torch.no_grad()
@@ -628,7 +617,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
@@ -651,23 +640,23 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
"mask_processor",
VaeImageProcessor,
config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
@@ -694,21 +683,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Copied from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
@@ -734,7 +723,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
@@ -784,8 +773,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
@@ -801,7 +790,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
block_state.image = block_state.image.to(dtype=torch.float32)
@@ -834,7 +823,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
# Encode
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]

View File

@@ -13,44 +13,40 @@
# limitations under the License.
from ..modular_pipeline_utils import InsertableOrderedDict
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
)
from .decoders import StableDiffusionXLAutoDecodeStep, StableDiffusionXLDecodeStep, StableDiffusionXLInpaintDecodeStep
# Import all the necessary block classes
from .denoise import (
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseLoop,
StableDiffusionXLInpaintDenoiseLoop
)
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep
StableDiffusionXLInpaintDenoiseLoop,
)
from .encoders import (
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLVaeEncoderStep,
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintDecodeStep,
StableDiffusionXLAutoDecodeStep
StableDiffusionXLIPAdapterStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLVaeEncoderStep,
)
# YiYi notes: comment out for now, work on this later
# block mapping
# block mapping
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),

View File

@@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
import numpy as np
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...image_processor import PipelineImageInput
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularLoader
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .denoise import StableDiffusionXLAutoDenoiseStep
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep
from .denoise import StableDiffusionXLAutoDenoiseStep
from .encoders import (
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -15,12 +15,12 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
import signal
import inspect
import json
import os
import re
import shutil
import signal
import sys
import threading
from pathlib import Path
@@ -531,4 +531,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module)
return get_class_in_module(class_name, final_module)