mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

Commit: style
@@ -794,8 +794,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LayerSkipConfig,
        PyramidAttentionBroadcastConfig,
        SmoothedEnergyGuidanceConfig,
        apply_layer_skip,
        apply_faster_cache,
        apply_layer_skip,
        apply_pyramid_attention_broadcast,
    )
    from .models import (
@@ -875,6 +875,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        WanTransformer3DModel,
        WanVACETransformer3DModel,
    )
    from .modular_pipelines import (
        ComponentsManager,
        ComponentSpec,
        ModularLoader,
        ModularPipeline,
        ModularPipelineBlocks,
    )
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
@@ -907,13 +914,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        ScoreSdeVePipeline,
        StableDiffusionMixin,
    )
    from .modular_pipelines import (
        ModularLoader,
        ModularPipeline,
        ModularPipelineBlocks,
        ComponentSpec,
        ComponentsManager,
    )
    from .quantizers import DiffusersQuantizer
    from .schedulers import (
        AmusedScheduler,
@@ -978,6 +978,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_pipelines import (
            StableDiffusionXLAutoPipeline,
            StableDiffusionXLModularLoader,
        )
        from .pipelines import (
            AllegroPipeline,
            AltDiffusionImg2ImgPipeline,
@@ -1182,10 +1186,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            WuerstchenDecoderPipeline,
            WuerstchenPriorPipeline,
        )
        from .modular_pipelines import (
            StableDiffusionXLAutoPipeline,
            StableDiffusionXLModularLoader,
        )

    try:
        if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
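These `__init__.py` hunks only re-sort entries in diffusers' lazy-import tables. For readers unfamiliar with the pattern, here is a simplified, self-contained sketch of how such a package `__init__.py` defers heavy imports; the module and symbol names in the table below are illustrative, not the real import structure:

```python
# Minimal sketch of a lazy-importing package __init__.py (PEP 562).
import importlib
from typing import TYPE_CHECKING

# Hypothetical table: submodule -> exported names.
_import_structure = {"models": ["UNet2DModel"], "pipelines": ["DiffusionPipeline"]}

if TYPE_CHECKING:
    # Type checkers (and "slow import" mode) see real imports...
    from .models import UNet2DModel
    from .pipelines import DiffusionPipeline
else:
    # ...while at runtime a submodule is only imported on first attribute access.
    def __getattr__(name):
        for module_name, names in _import_structure.items():
            if name in names:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```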
@@ -18,10 +18,11 @@ Usage example:
"""

import ast
from argparse import ArgumentParser, Namespace
from pathlib import Path
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path

from ..utils import logging
from . import BaseDiffusersCLICommand

@@ -57,7 +58,7 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
        # determine the block to be saved.
        out = self._get_class_names(self.block_module_name)
        classes_found = list({cls for cls, _ in out})

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(out, self.block_class_name)
            if child_class is None and parent_class is None:
@@ -125,9 +126,9 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}
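For reference, `_create_automap` only assembles the `auto_map` entry saved alongside a custom block; a quick worked example with hypothetical names:

```python
# Hypothetical inputs: block_module_name="my_block.py",
# parent_class="ModularPipelineBlocks", child_class="MyBlock".
module = "my_block.py".replace(".py", "").rsplit(".", 1)[-1]  # -> "my_block"
auto_map = {"ModularPipelineBlocks": f"{module}.MyBlock"}
print({"auto_map": auto_map})
# {'auto_map': {'ModularPipelineBlocks': 'my_block.MyBlock'}}
```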
@@ -15,9 +15,9 @@

from argparse import ArgumentParser

from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
from .custom_blocks import CustomBlocksCommand


def main():

@@ -13,12 +13,13 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -74,10 +75,10 @@ class AdaptiveProjectedGuidance(BaseGuidance):
        self.momentum_buffer = None

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
@@ -123,19 +124,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
    def _is_apg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
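This enable/disable gating recurs nearly verbatim in every guider below; a small worked example of the arithmetic (settings invented for illustration):

```python
import math

# Assumed illustrative settings: guidance active over [0.1, 0.9) of sampling.
start, stop, num_inference_steps = 0.1, 0.9, 50
skip_start_step = int(start * num_inference_steps)   # 5
skip_stop_step = int(stop * num_inference_steps)     # 45
for step in (0, 5, 44, 45):
    print(step, skip_start_step <= step < skip_stop_step)
    # 0 False, 5 True, 44 True, 45 False

# A guidance scale of exactly 1.0 (or 0.0 under the "original formulation")
# also disables the guider, since the update would be a no-op:
print(math.isclose(1.0, 1.0))  # True -> disabled
```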
@@ -160,25 +161,25 @@ def normalized_guidance(
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel

    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update

    return pred
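The projection above splits the (possibly momentum-averaged, norm-clipped) guidance direction into components parallel and orthogonal to `pred_cond`, then re-weights the parallel part by `eta`. A self-contained check of that decomposition:

```python
import torch

torch.manual_seed(0)
pred_cond, pred_uncond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, diff.ndim)]             # all non-batch dims

v0 = diff.double()
v1 = pred_cond.double()
v1 = v1 / v1.norm(p=2, dim=dim, keepdim=True)       # unit vector along pred_cond
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel

# The two parts are orthogonal and reassemble the original direction exactly.
assert torch.allclose(v0_parallel + v0_orthogonal, v0)
assert (v0_parallel * v0_orthogonal).sum().abs() < 1e-8
```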
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -113,18 +114,18 @@ class AutoGuidance(BaseGuidance):
        if self._is_ag_enabled() and self.is_unconditional:
            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_ag_enabled() and self.is_unconditional:
            for name in self._auto_guidance_hook_names:
                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
                registry.remove_hook(name, recurse=True)

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
@@ -144,9 +145,9 @@ class AutoGuidance(BaseGuidance):

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1
@@ -161,17 +162,17 @@ class AutoGuidance(BaseGuidance):
    def _is_ag_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

@@ -13,12 +13,13 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -74,12 +75,12 @@ class ClassifierFreeGuidance(BaseGuidance):
        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
@@ -116,17 +117,17 @@ class ClassifierFreeGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close
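For orientation, the combination step that `ClassifierFreeGuidance` ultimately performs reduces to one line; a minimal sketch (the `guidance_rescale` branch applies `rescale_noise_cfg` on top of this):

```python
import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # Extrapolate from the unconditional prediction toward the conditional one.
    return pred_uncond + scale * (pred_cond - pred_uncond)
```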
@@ -13,12 +13,13 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -72,12 +73,12 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
        self.zero_init_steps = zero_init_steps
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
@@ -106,7 +107,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1
@@ -121,19 +122,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


@@ -58,10 +58,10 @@ class BaseGuidance:
    def disable(self):
        self._enabled = False

    def enable(self):
        self._enabled = True

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        self._step = step
        self._num_inference_steps = num_inference_steps
@@ -104,14 +104,14 @@ class BaseGuidance:
                    f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
                )
        self._input_fields = kwargs

    def prepare_models(self, denoiser: torch.nn.Module) -> None:
        """
        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
        subclasses to implement specific model preparation logic.
        """
        self._count_prepared += 1

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        """
        Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
@@ -119,7 +119,7 @@ class BaseGuidance:
        modifications made during `prepare_models`.
        """
        pass

    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

@@ -139,15 +139,15 @@ class BaseGuidance:
    @property
    def is_conditional(self) -> bool:
        raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

    @property
    def is_unconditional(self) -> bool:
        return not self.is_conditional

    @property
    def num_conditions(self) -> int:
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @classmethod
    def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
        """

@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

@@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -148,19 +149,19 @@ class SkipLayerGuidance(BaseGuidance):
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
                _apply_layer_skip_hook(denoiser, config, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._skip_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
@@ -204,7 +205,7 @@ class SkipLayerGuidance(BaseGuidance):
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3
@@ -221,31 +222,31 @@ class SkipLayerGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_slg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)

        return is_within_range and not is_zero
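`SkipLayerGuidance` prepares up to three batches per step (conditional, unconditional, and conditional with skipped layers; note `is_conditional` is true on the 1st and 3rd preparation). The combination of the three predictions is not shown in this diff; the sketch below is the usual SLG-style form, stated as an assumption rather than a verbatim excerpt:

```python
import torch

def slg_combine(pred_cond, pred_uncond, pred_cond_skip, guidance_scale, slg_scale):
    # CFG term plus a term pushing away from the layer-skipped prediction (assumed form).
    pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
    pred = pred + slg_scale * (pred_cond - pred_cond_skip)
    return pred
```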
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

@@ -21,6 +21,7 @@ from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -141,19 +142,19 @@ class SmoothedEnergyGuidance(BaseGuidance):
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
                _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)

    def cleanup_models(self, denoiser: torch.nn.Module):
        if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
            # Remove the hooks after inference
            for hook_name in self._seg_layer_hook_names:
                registry.remove_hook(hook_name, recurse=True)

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        if self.num_conditions == 1:
            tuple_indices = [0]
            input_predictions = ["pred_cond"]
@@ -197,7 +198,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred, {}

    @property
    def is_conditional(self) -> bool:
        return self._count_prepared == 1 or self._count_prepared == 3
@@ -214,31 +215,31 @@ class SmoothedEnergyGuidance(BaseGuidance):
    def _is_cfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close

    def _is_seg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
            skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
            is_within_range = skip_start_step < self._step < skip_stop_step

        is_zero = math.isclose(self.seg_guidance_scale, 0.0)

        return is_within_range and not is_zero

@@ -13,12 +13,13 @@
# limitations under the License.

import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from .guider_utils import BaseGuidance, rescale_noise_cfg


if TYPE_CHECKING:
    from ..modular_pipelines.modular_pipeline import BlockState

@@ -63,10 +64,10 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
        if input_fields is None:
            input_fields = self._input_fields

        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for i in range(self.num_conditions):
@@ -101,24 +102,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
    def _is_tcfg_enabled(self) -> bool:
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        is_close = False
        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scale, 0.0)
        else:
            is_close = math.isclose(self.guidance_scale, 1.0)

        return is_within_range and not is_close


def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
    cond_dtype = pred_cond.dtype
    cond_dtype = pred_cond.dtype
    preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
    preds = preds.flatten(2)
    U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
@@ -129,9 +130,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid
    x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
    x_Vh_V = torch.matmul(x_Vh, Vh_modified)
    pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)

    pred = pred_cond if use_original_formulation else pred_uncond
    shift = pred_cond - pred_uncond
    pred = pred + guidance_scale * shift

    return pred
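This TCFG `normalized_guidance` re-expresses the unconditional prediction in the SVD basis of the stacked (cond, uncond) pair before the usual CFG shift; the diff omits how `Vh_modified` is built. A shape-level sketch of the projection (illustrative only):

```python
import torch

pred_cond, pred_uncond = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
preds = torch.stack([pred_cond, pred_uncond], dim=1).float().flatten(2)  # (B, 2, N)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)                  # Vh: (B, 2, N)

uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1)            # (B, 1, N)
x_Vh = uncond_flat @ Vh.transpose(-2, -1)                                # (B, 1, 2): coords in the SVD row basis
x_Vh_V = x_Vh @ Vh                                                       # back-projection; the source uses Vh_modified here
print(x_Vh_V.shape)                                                      # torch.Size([2, 1, 16384])
```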
@@ -20,7 +20,12 @@ import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._common import (
    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
    _ATTENTION_CLASSES,
    _FEEDFORWARD_CLASSES,
    _get_submodule_from_fqn,
)
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook

@@ -198,15 +203,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        if config.skip_attention and config.skip_ff:
            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
            registry = HookRegistry.check_if_exists_or_initialize(block)
            hook = TransformerBlockSkipHook(config.dropout)
            registry.register_hook(hook, name)

        elif config.skip_attention or config.skip_attention_scores:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
@@ -215,7 +220,7 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
                    registry.register_hook(hook, name)

        if config.skip_ff:
            for submodule_name, submodule in block.named_modules():
                if isinstance(submodule, _FEEDFORWARD_CLASSES):

@@ -14,7 +14,7 @@

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional

import torch
import torch.nn.functional as F
@@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook):

def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
    name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK

    if config.fqn == "auto":
        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
            if hasattr(module, identifier):
@@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
            )

    if config._query_proj_identifiers is None:
        config._query_proj_identifiers = ["to_q"]

    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
    blocks_found = False
    for i, block in enumerate(transformer_blocks):
        if i not in config.indices:
            continue

        blocks_found = True

        for submodule_name, submodule in block.named_modules():
            if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
                continue
@@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
            registry = HookRegistry.check_if_exists_or_initialize(query_proj)
            hook = SmoothedEnergyGuidanceHook(blur_sigma)
            registry.register_hook(hook, name)

    if not blocks_found:
        raise ValueError(
            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
@@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
    in the future without warning or guarantee of reproducibility.
    """
    assert query.ndim == 3

    is_inf = sigma > sigma_threshold_inf
    batch_size, seq_len, embed_dim = query.shape

@@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
    query_slice = query[:, :num_square_tokens, :]
    query_slice = query_slice.permute(0, 2, 1)
    query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)

    if is_inf:
        kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
        kernel_size_half = (kernel_size - 1) / 2
@@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
    query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
    query_slice = query_slice.permute(0, 2, 1)
    query[:, :num_square_tokens, :] = query_slice.clone()

    return query
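`_gaussian_blur_2d` treats the first `seq_len_sqrt**2` tokens as a square image over the embedding channels, blurs it, and writes the result back. A compact sketch of that reshape-blur-restore round trip (simplified fixed-kernel variant, not the source implementation):

```python
import torch
import torch.nn.functional as F

def blur_square_tokens(query: torch.Tensor, kernel_size: int = 3, sigma: float = 1.0) -> torch.Tensor:
    batch_size, seq_len, embed_dim = query.shape
    side = int(seq_len**0.5)          # seq_len_sqrt in the source
    n = side * side                   # num_square_tokens
    x = query[:, :n, :].permute(0, 2, 1).reshape(batch_size, embed_dim, side, side)

    # Build a normalized 2D Gaussian kernel, applied depthwise per channel.
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    g = torch.exp(-(coords**2) / (2 * sigma**2))
    kernel_2d = torch.outer(g, g)
    kernel_2d = (kernel_2d / kernel_2d.sum()).expand(embed_dim, 1, kernel_size, kernel_size)
    x = F.conv2d(x, kernel_2d, padding=kernel_size // 2, groups=embed_dim)

    out = query.clone()
    out[:, :n, :] = x.reshape(batch_size, embed_dim, n).permute(0, 2, 1)
    return out

print(blur_square_tokens(torch.randn(1, 16, 8)).shape)  # torch.Size([1, 16, 8])
```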
@@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .ip_adapter import (
        FluxIPAdapterMixin,
        IPAdapterMixin,
        SD3IPAdapterMixin,
        ModularIPAdapterMixin,
        SD3IPAdapterMixin,
    )
    from .lora_pipeline import (
        AmusedLoraLoaderMixin,

@@ -49,13 +49,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_pt_objects import *  # noqa F403
    else:
        from .components_manager import ComponentsManager
        from .modular_pipeline import (
            AutoPipelineBlocks,
            BlockState,
            LoopSequentialPipelineBlocks,
            ModularLoader,
            ModularPipelineBlocks,
            ModularPipeline,
            ModularPipelineBlocks,
            PipelineBlock,
            PipelineState,
            SequentialPipelineBlocks,
@@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            StableDiffusionXLAutoPipeline,
            StableDiffusionXLModularLoader,
        )
        from .components_manager import ComponentsManager
else:
    import sys

@@ -12,24 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import uuid
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy
from typing import Any, Dict, List, Optional, Union

import torch
import time
from dataclasses import dataclass

from ..utils import (
    is_accelerate_available,
    logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec

import uuid

if is_accelerate_available():
@@ -237,12 +232,12 @@ class AutoOffloadStrategy:
class ComponentsManager:
    def __init__(self):
        self.components = OrderedDict()
        self.added_time = OrderedDict() # Store when components were added
        self.added_time = OrderedDict()  # Store when components were added
        self.collections = OrderedDict()  # collection_name -> set of component_names
        self.model_hooks = None
        self._auto_offload_enabled = False

    def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
        """
        Lookup component_ids by name, collection, or load_id.
@@ -251,7 +246,7 @@ class ComponentsManager:
            components = self.components

        if name:
            ids_by_name = set()
            ids_by_name = set()
            for component_id, component in components.items():
                comp_name = self._id_to_name(component_id)
                if comp_name == name:
@@ -272,16 +267,16 @@ class ComponentsManager:
                    ids_by_load_id.add(name)
        else:
            ids_by_load_id = set(components.keys())

        ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
        return ids

    @staticmethod
    def _id_to_name(component_id: str):
        return "_".join(component_id.split("_")[:-1])

    def add(self, name, component, collection: Optional[str] = None):
        component_id = f"{name}_{uuid.uuid4()}"

        # check for duplicated components
@@ -305,7 +300,7 @@ class ComponentsManager:
        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
            components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
            components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]

            if components_with_same_load_id:
                existing = ", ".join(components_with_same_load_id)
                logger.warning(
@@ -320,7 +315,7 @@ class ComponentsManager:
        if collection:
            if collection not in self.collections:
                self.collections[collection] = set()
            if not component_id in self.collections[collection]:
            if component_id not in self.collections[collection]:
                comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
                for comp_id in comp_ids_in_collection:
                    logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
@@ -331,8 +326,8 @@ class ComponentsManager:
        logger.info(f"Added component '{name}' as '{component_id}'")

        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)
            self.enable_auto_cpu_offload(self._auto_offload_device)

        return component_id

@@ -341,14 +336,14 @@ class ComponentsManager:
        if component_id not in self.components:
            logger.warning(f"Component '{component_id}' not found in ComponentsManager")
            return

        component = self.components.pop(component_id)
        self.added_time.pop(component_id)

        for collection in self.collections:
            if component_id in self.collections[collection]:
                self.collections[collection].remove(component_id)

        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)
        else:
@@ -386,7 +381,7 @@ class ComponentsManager:
            Dictionary mapping component IDs to components,
            or list of (base_name, component) tuples if as_name_component_tuples=True
        """
        selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
        components = {k: self.components[k] for k in selected_ids}

@@ -397,16 +392,16 @@ class ComponentsManager:
            if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
                return '_'.join(parts[:-1])
            return component_id

        if names is None:
            if as_name_component_tuples:
                return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
            else:
                return components

        # Create mapping from component_id to base_name for all components
        base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}

        def matches_pattern(component_id, pattern, exact_match=False):
            """
            Helper function to check if a component matches a pattern based on its base name.
@@ -417,124 +412,124 @@ class ComponentsManager:
                exact_match: If True, only exact matches to base_name are considered
            """
            base_name = base_names[component_id]

            # Exact match with base name
            if exact_match:
                return pattern == base_name

            # Prefix match (ends with *)
            elif pattern.endswith('*'):
                prefix = pattern[:-1]
                return base_name.startswith(prefix)

            # Contains match (starts with *)
            elif pattern.startswith('*'):
                search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
                return search in base_name

            # Exact match (no wildcards)
            else:
                return pattern == base_name

        if isinstance(names, str):
            # Check if this is a "not" pattern
            is_not_pattern = names.startswith('!')
            if is_not_pattern:
                names = names[1:]  # Remove the ! prefix

            # Handle OR patterns (containing |)
            if '|' in names:
                terms = names.split('|')
                matches = {}

                for comp_id, comp in components.items():
                    # For OR patterns with exact names (no wildcards), we do exact matching on base names
                    exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)

                    # Check if any of the terms match this component
                    should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)

                    # Flip the decision if this is a NOT pattern
                    if is_not_pattern:
                        should_include = not should_include

                    if should_include:
                        matches[comp_id] = comp

                log_msg = "NOT " if is_not_pattern else ""
                match_type = "exactly matching" if exact_match else "matching any of patterns"
                logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")

            # Try exact match with a base name
            elif any(names == base_name for base_name in base_names.values()):
                # Find all components with this base name
                matches = {
                    comp_id: comp for comp_id, comp in components.items()
                    comp_id: comp for comp_id, comp in components.items()
                    if (base_names[comp_id] == names) != is_not_pattern
                }

                if is_not_pattern:
                    logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")

            # Prefix match (ends with *)
            elif names.endswith('*'):
                prefix = names[:-1]
                matches = {
                    comp_id: comp for comp_id, comp in components.items()
                    comp_id: comp for comp_id, comp in components.items()
                    if base_names[comp_id].startswith(prefix) != is_not_pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")

            # Contains match (starts with *)
            elif names.startswith('*'):
                search = names[1:-1] if names.endswith('*') else names[1:]
                matches = {
                    comp_id: comp for comp_id, comp in components.items()
                    comp_id: comp for comp_id, comp in components.items()
                    if (search in base_names[comp_id]) != is_not_pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

            # Substring match (no wildcards, but not an exact component name)
            elif any(names in base_name for base_name in base_names.values()):
                matches = {
                    comp_id: comp for comp_id, comp in components.items()
                    comp_id: comp for comp_id, comp in components.items()
                    if (names in base_names[comp_id]) != is_not_pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components containing '{names}': {list(matches.keys())}")

            else:
                raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")

            if not matches:
                raise ValueError(f"No components found matching pattern '{names}'")

            if as_name_component_tuples:
                return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
            else:
                return matches

        elif isinstance(names, list):
            results = {}
            for name in names:
                result = self.get(name, collection, load_id, as_name_component_tuples=False)
                results.update(result)

            if as_name_component_tuples:
                return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
            else:
                return results

        else:
            raise ValueError(f"Invalid type for names: {type(names)}")
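Taken together, `ComponentsManager.get` supports exact names, `prefix*`, `*substring*`, `a|b` OR-lists, and a leading `!` for negation. The sketch below simulates just the string-matching rules on plain names (no diffusers objects needed; the names are invented):

```python
base_names = ["unet", "vae", "text_encoder", "text_encoder_2"]

def matches(pattern: str, name: str) -> bool:
    if pattern.endswith("*"):           # prefix match
        return name.startswith(pattern[:-1])
    if pattern.startswith("*"):         # contains match
        return pattern.strip("*") in name
    return pattern == name              # exact match

print([n for n in base_names if matches("text_encoder*", n)])
# ['text_encoder', 'text_encoder_2']
print([n for n in base_names if any(matches(t, n) for t in "unet|vae".split("|"))])
# ['unet', 'vae']
print([n for n in base_names if not matches("vae", n)])  # '!' flips the match
# ['unet', 'text_encoder', 'text_encoder_2']
```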
@@ -595,14 +590,14 @@ class ComponentsManager:
            raise ValueError(f"Component '{component_id}' not found in ComponentsManager")

        component = self.components[component_id]

        # Build complete info dict first
        info = {
            "model_id": component_id,
            "added_time": self.added_time[component_id],
            "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
        }

        # Additional info for torch.nn.Module components
        if isinstance(component, torch.nn.Module):
            # Check for hook information
@@ -610,7 +605,7 @@ class ComponentsManager:
            execution_device = None
            if has_hook and hasattr(component._hf_hook, "execution_device"):
                execution_device = component._hf_hook.execution_device

            info.update({
                "class_name": component.__class__.__name__,
                "size_gb": get_memory_footprint(component) / (1024**3),
@@ -631,8 +626,8 @@ class ComponentsManager:
            if any("IPAdapter" in ptype for ptype in processor_types):
                # Then get scales only from IP-Adapter processors
                scales = {
                    k: v.scale
                    for k, v in processors.items()
                    k: v.scale
                    for k, v in processors.items()
                    if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
                }
                if scales:
@@ -646,7 +641,7 @@ class ComponentsManager:
        else:
            # List of fields requested, return dict with just those fields
            return {k: v for k, v in info.items() if k in fields}

        return info

    def __repr__(self):
@@ -659,13 +654,13 @@ class ComponentsManager:
            if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
                return '_'.join(parts[:-1])
            return name

        # Extract load_id if available
        def get_load_id(component):
            if hasattr(component, "_diffusers_load_id"):
                return component._diffusers_load_id
            return "N/A"

        # Format device info compactly
        def format_device(component, info):
            if not info["has_hook"]:
@@ -674,18 +669,18 @@ class ComponentsManager:
            device = str(getattr(component, 'device', 'N/A'))
            exec_device = str(info['execution_device'] or 'N/A')
            return f"{device}({exec_device})"

        # Get all simple names to calculate width
        simple_names = [get_simple_name(id) for id in self.components.keys()]

        # Get max length of load_ids for models
        load_ids = [
            get_load_id(component)
            for component in self.components.values()
            get_load_id(component)
            for component in self.components.values()
            if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
        ]
        max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15

        # Get all collections for each component
        component_collections = {}
        for name in self.components.keys():
@@ -695,11 +690,11 @@ class ComponentsManager:
                    component_collections[name].append(coll)
            if not component_collections[name]:
                component_collections[name] = ["N/A"]

        # Find the maximum collection name length
        all_collections = [coll for colls in component_collections.values() for coll in colls]
        max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10

        col_widths = {
            "name": max(15, max(len(name) for name in simple_names)),
            "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
@@ -736,21 +731,21 @@ class ComponentsManager:
            device_str = format_device(component, info)
            dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
            load_id = get_load_id(component)

            # Print first collection on the main line
            first_collection = component_collections[name][0] if component_collections[name] else "N/A"

            output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
            output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
            output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"

            # Print additional collections on separate lines if they exist
            for i in range(1, len(component_collections[name])):
                collection = component_collections[name][i]
                output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
                output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
                output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"

        output += dash_line

        # Other components section
@@ -766,17 +761,17 @@ class ComponentsManager:
        for name, component in others.items():
            info = self.get_model_info(name)
            simple_name = get_simple_name(name)

            # Print first collection on the main line
            first_collection = component_collections[name][0] if component_collections[name] else "N/A"

            output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"

            # Print additional collections on separate lines if they exist
            for i in range(1, len(component_collections[name])):
                collection = component_collections[name][i]
                output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"

        output += dash_line

        # Add additional component info
@@ -789,8 +784,8 @@ class ComponentsManager:
            if info.get("adapters") is not None:
                output += f" Adapters: {info['adapters']}\n"
            if info.get("ip_adapter"):
                output += f" IP-Adapter: Enabled\n"
                output += " IP-Adapter: Enabled\n"

        return output

    def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
@@ -821,13 +816,13 @@ class ComponentsManager:
        from ..pipelines.pipeline_utils import DiffusionPipeline
        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
        for name, component in pipe.components.items():
            if component is None:
                continue

            # Add prefix if specified
            component_name = f"{prefix}_{name}" if prefix else name

            if component_name not in self.components:
                self.add(component_name, component)
            else:
@@ -860,15 +855,15 @@ class ComponentsManager:
            if component_id not in self.components:
                raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
            return self.components[component_id]

        results = self.get(name, collection, load_id)

        if not results:
            raise ValueError(f"No components found matching '{name}'")

        if len(results) > 1:
            raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")

        return next(iter(results.values()))

def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -894,17 +889,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
        if value_tuple not in value_to_keys:
            value_to_keys[value_tuple] = []
        value_to_keys[value_tuple].append(key)

    def find_common_prefix(keys: List[str]) -> str:
        """Find the shortest common prefix among a list of dot-separated keys."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        # Split all keys into parts
        key_parts = [k.split('.') for k in keys]

        # Find how many initial parts are common
        common_length = 0
        for parts in zip(*key_parts):
@@ -912,10 +907,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
                common_length += 1
            else:
                break

        if common_length == 0:
            return ""

        # Return the common prefix
        return '.'.join(key_parts[0][:common_length])

@@ -929,5 +924,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
            summary[prefix] = value
        else:
            summary[""] = value  # Use empty string if no common prefix

    return summary
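To make `summarize_dict_by_value_and_parts` concrete: keys that share a value are grouped and collapsed to their shortest common dotted prefix (example values invented):

```python
d = {
    "down_blocks.0.attn.scale": 1.0,
    "down_blocks.0.ff.scale": 1.0,
    "up_blocks.1.attn.scale": 0.5,
}
# Grouping by value and collapsing each group to its common prefix yields:
# {"down_blocks.0": 1.0, "up_blocks.1.attn.scale": 0.5}
```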
[File diff suppressed because it is too large]
@@ -12,44 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal

from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union

from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available


if is_torch_available():
    import torch
    pass


class InsertableOrderedDict(OrderedDict):
    def insert(self, key, value, index):
        items = list(self.items())

        # Remove key if it already exists to avoid duplicates
        items = [(k, v) for k, v in items if k != key]

        # Insert at the specified index
        items.insert(index, (key, value))

        # Clear and update self
        self.clear()
        self.update(items)

        # Return self for method chaining
        return self

    def __repr__(self):
        if not self:
            return "InsertableOrderedDict()"

        items = []
        for i, (key, value) in enumerate(self.items()):
            items.append(f"{i}: ({repr(key)}, {repr(value)})")

        return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])"
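A short usage sketch of the new `InsertableOrderedDict` (values invented; assumes the class definition above):

```python
d = InsertableOrderedDict([("encode", 1), ("decode", 3)])
d.insert("denoise", 2, index=1)     # returns self, so calls can chain
print(list(d.items()))              # [('encode', 1), ('denoise', 2), ('decode', 3)]
d.insert("encode", 0, index=2)      # re-inserting an existing key moves it
print(list(d.keys()))               # ['denoise', 'decode', 'encode']
```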
@@ -85,24 +86,24 @@ class ComponentSpec:
    variant: Optional[str] = field(default=None, metadata={"loading": True})
    revision: Optional[str] = field(default=None, metadata={"loading": True})
    default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"

    def __hash__(self):
        """Make ComponentSpec hashable, using load_id as the hash value."""
        return hash((self.name, self.load_id, self.default_creation_method))

    def __eq__(self, other):
        """Compare ComponentSpec objects based on name and load_id."""
        if not isinstance(other, ComponentSpec):
            return False
        return (self.name == other.name and
                self.load_id == other.load_id and
        return (self.name == other.name and
                self.load_id == other.load_id and
                self.default_creation_method == other.default_creation_method)

    @classmethod
    def from_component(cls, name: str, component: Any) -> Any:
        """Create a ComponentSpec from a Component created by `create` or `load` method."""
        if not hasattr(component, "_diffusers_load_id"):
            raise ValueError("Component is not created by `create` or `load` method")
        # throw a error if component is created with `create` method but not a subclass of ConfigMixin
@@ -113,19 +114,19 @@ class ComponentSpec:
            "created with `ComponentSpec.load` method"
            "or created with `ComponentSpec.create` and a subclass of ConfigMixin"
        )

        type_hint = component.__class__
        default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"

        if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
            config = component.config
        else:
            config = None

        load_spec = cls.decode_load_id(component._diffusers_load_id)

        return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)

    @classmethod
    def loading_fields(cls) -> List[str]:
        """
@@ -133,8 +134,8 @@ class ComponentSpec:
        (i.e. those whose field.metadata["loading"] is True).
        """
        return [f.name for f in fields(cls) if f.metadata.get("loading", False)]

    @property
    def load_id(self) -> str:
        """
@@ -144,7 +145,7 @@ class ComponentSpec:
        parts = [getattr(self, k) for k in self.loading_fields()]
        parts = ["null" if p is None else p for p in parts]
        return "|".join(p for p in parts if p)

    @classmethod
    def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
        """
@@ -165,26 +166,26 @@ class ComponentSpec:
        If a segment value is "null", it's replaced with None.
        Returns None if load_id is "null" (indicating component not created with `load` method).
        """
        # Get all loading fields in order
        loading_fields = cls.loading_fields()
        result = {f: None for f in loading_fields}

        if load_id == "null":
            return result

        # Split the load_id
        parts = load_id.split("|")

        # Map parts to loading fields by position
        for i, part in enumerate(parts):
            if i < len(loading_fields):
                # Convert "null" string back to None
                result[loading_fields[i]] = None if part == "null" else part

        return result

    # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
    # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
    # the config info is lost in the process
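`load_id` is a positional `"|"`-join of the loading fields, with `"null"` standing in for unset values; a round-trip illustration (field order and repo value assumed for the example):

```python
fields = ["repo", "subfolder", "variant", "revision"]  # assumed loading-field order

# Encode, mirroring the load_id property:
parts = ["stabilityai/stable-diffusion-xl-base-1.0", "unet", None, None]
load_id = "|".join("null" if p is None else p for p in parts)
print(load_id)  # stabilityai/stable-diffusion-xl-base-1.0|unet|null|null

# Decode, mirroring decode_load_id: map segments back by position.
decoded = {f: None for f in fields}
for i, part in enumerate(load_id.split("|")):
    decoded[fields[i]] = None if part == "null" else part
print(decoded["subfolder"])  # unet
```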
@@ -194,11 +195,11 @@ class ComponentSpec:
|
||||
|
||||
if self.type_hint is None or not isinstance(self.type_hint, type):
|
||||
raise ValueError(
|
||||
f"`type_hint` is required when using from_config creation method."
|
||||
"`type_hint` is required when using from_config creation method."
|
||||
)
|
||||
|
||||
|
||||
config = config or self.config or {}
|
||||
|
||||
|
||||
if issubclass(self.type_hint, ConfigMixin):
|
||||
component = self.type_hint.from_config(config, **kwargs)
|
||||
else:
|
||||
@@ -211,17 +212,17 @@ class ComponentSpec:
|
||||
if k in signature_params:
|
||||
init_kwargs[k] = v
|
||||
component = self.type_hint(**init_kwargs)
|
||||
|
||||
|
||||
component._diffusers_load_id = "null"
|
||||
if hasattr(component, "config"):
|
||||
self.config = component.config
|
||||
|
||||
|
||||
return component
|
||||
|
||||
|
||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||
def load(self, **kwargs) -> Any:
|
||||
"""Load component using from_pretrained."""
|
||||
|
||||
|
||||
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
|
||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||
# merge loading field value in the spec with user passed values to create load_kwargs
|
||||
@@ -229,8 +230,8 @@ class ComponentSpec:
|
||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||
repo = load_kwargs.pop("repo", None)
|
||||
if repo is None:
|
||||
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||
|
||||
raise ValueError("`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
|
||||
|
||||
if self.type_hint is None:
|
||||
try:
|
||||
from diffusers import AutoModel
|
||||
@@ -244,17 +245,17 @@ class ComponentSpec:
|
||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
||||
|
||||
|
||||
self.repo = repo
|
||||
for k, v in load_kwargs.items():
|
||||
setattr(self, k, v)
|
||||
component._diffusers_load_id = self.load_id
|
||||
|
||||
|
||||
return component
|
||||
|
||||
|
||||
|
||||
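For reference, the load_id scheme above round-trips as follows — a minimal standalone sketch that mirrors the encode/decode logic shown in these hunks (the field names repo/subfolder/variant/revision are taken from the comment in `load`; the real class carries more machinery):

    from typing import Dict, Optional

    LOADING_FIELDS = ["repo", "subfolder", "variant", "revision"]  # assumed field order

    def encode_load_id(values: Dict[str, Optional[str]]) -> str:
        # None becomes the literal "null"; segments are joined with "|"
        parts = ["null" if values.get(f) is None else values[f] for f in LOADING_FIELDS]
        return "|".join(p for p in parts if p)

    def decode_load_id(load_id: str) -> Dict[str, Optional[str]]:
        result = {f: None for f in LOADING_FIELDS}
        if load_id == "null":  # component was not created with `load`
            return result
        for field, part in zip(LOADING_FIELDS, load_id.split("|")):
            result[field] = None if part == "null" else part
        return result

    spec = {"repo": "stabilityai/stable-diffusion-xl-base-1.0", "subfolder": "unet", "variant": None, "revision": None}
    load_id = encode_load_id(spec)  # "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"
    assert decode_load_id(load_id)["subfolder"] == "unet"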
@dataclass

@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
@@ -281,7 +282,7 @@ class InputParam:
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"

@dataclass
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
@@ -315,14 +316,14 @@ def format_inputs_short(inputs):
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]

required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)

inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str

return inputs_str

@@ -353,18 +354,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
else:
inp_name = inp.name
input_parts.append(inp_name)

# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = []
new_output_parts = []

for out in intermediates_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)

result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
@@ -372,7 +373,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")

return "\n".join(result) if result else " (none)"

@@ -390,18 +391,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""
if not params:
return ""

base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []

def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)

def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
@@ -411,7 +412,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):

for word in words:
word_length = len(word) + (1 if current_line else 0)

if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
@@ -419,22 +420,22 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
else:
current_line.append(word)
current_length += word_length

if current_line:
lines.append(" ".join(current_line))

return f"\n{indent}".join(lines)

# Add the header
formatted_params.append(f"{base_indent}{header}:")

for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"

# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
@@ -442,7 +443,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"

# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(
@@ -452,9 +453,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"

formatted_params.append(param_str)

return "\n\n".join(formatted_params)

@@ -500,42 +501,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
"""
if not components:
return ""

base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []

# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")

# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)

component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"

# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")

# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"

formatted_components.append(component_desc)

# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")

return "\n".join(formatted_components)

@@ -553,27 +554,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
"""
if not configs:
return ""

base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []

# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")

# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)

# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")

return "\n".join(formatted_configs)

@@ -618,9 +619,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class

# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)

# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)

return output
return output
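The greedy wrapper inside format_params is self-contained; a quick standalone rendition of the same logic (the length reset after a line break is elided by the hunk boundary above and filled in here) shows how the indent is applied to continuation lines only:

    def wrap_text(text: str, indent: str, max_length: int) -> str:
        # Greedy word wrap; the caller prefixes the first line with `indent`,
        # so the join only inserts it before continuation lines.
        words = text.split()
        lines, current_line, current_length = [], [], 0
        for word in words:
            word_length = len(word) + (1 if current_line else 0)
            if current_line and current_length + word_length > max_length:
                lines.append(" ".join(current_line))
                current_line, current_length = [word], len(word)
            else:
                current_line.append(word)
                current_length += word_length
        if current_line:
            lines.append(" ".join(current_line))
        return f"\n{indent}".join(lines)

    print(wrap_text("a long description that needs wrapping " * 4, " " * 8, 40))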
@@ -1,16 +1,19 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os

from typing import Union, List, Optional, Tuple
import torch
import PIL
import numpy as np
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam

logger = logging.getLogger(__name__)

# YiYi Notes: this is actually for SDXL, put it here for now
@@ -189,8 +192,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
if group_key in name:
return group_name
return None


class ModularNode(ConfigMixin):

config_name = "node_config.json"
@@ -214,15 +217,15 @@ class ModularNode(ConfigMixin):
self.name_mapping = {}

input_params = {}
# pass or create a default param dict for each input
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
@@ -236,10 +239,10 @@ class ModularNode(ConfigMixin):
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue

if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue

if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
@@ -248,7 +251,7 @@ class ModularNode(ConfigMixin):
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
@@ -285,7 +288,7 @@ class ModularNode(ConfigMixin):
break
if to_exclude:
continue

if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
@@ -303,7 +306,7 @@ class ModularNode(ConfigMixin):
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs

for out in outputs:
param = kwargs.pop(out.name, None)
if param:
@@ -326,10 +329,10 @@ class ModularNode(ConfigMixin):
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param

if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")

register_dict = {
"category": category,
"label": label,
@@ -339,7 +342,7 @@ class ModularNode(ConfigMixin):
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)

def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@@ -347,7 +350,7 @@ class ModularNode(ConfigMixin):
@property
def mellon_config(self):
return self._convert_to_mellon_config()

def _convert_to_mellon_config(self):

node = {}
@@ -368,13 +371,13 @@ class ModularNode(ConfigMixin):
}
else:
param = inp_param

if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")

for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
@@ -388,13 +391,13 @@ class ModularNode(ConfigMixin):
}
else:
param = comp_param

if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")

for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
@@ -408,7 +411,7 @@ class ModularNode(ConfigMixin):
}
else:
param = out_param

if mellon_name not in node_param:
node_param[mellon_name] = param
else:
@@ -427,22 +430,22 @@ class ModularNode(ConfigMixin):
Path: Path to the saved config file
"""
file_path = Path(file_path)

# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)

# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}

# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)

logger.info(f"Mellon config and name mapping saved to {file_path}")

return file_path

@classmethod
@@ -457,16 +460,16 @@ class ModularNode(ConfigMixin):
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)

if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")

with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)

logger.info(f"Mellon config loaded from {file_path}")

return config

def process_inputs(self, **kwargs):
@@ -483,7 +486,7 @@ class ModularNode(ConfigMixin):
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
@@ -495,14 +498,14 @@ class ModularNode(ConfigMixin):
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp

return_output_names = list(self.config.output_params.keys())

return params_components, params_run, return_output_names

def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)

self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output
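The save/load pair above writes a single JSON document of the shape {"module": ..., "name_mapping": ...}; a minimal standalone round-trip sketch of that format (function names here are illustrative — the real methods live on ModularNode and their names are partly elided by the hunks):

    import json
    from pathlib import Path

    def save_node_config(module: dict, name_mapping: dict, file_path: str) -> Path:
        # Same JSON shape as ModularNode writes: module definition + name mapping
        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump({"module": module, "name_mapping": name_mapping}, f, indent=2)
        return file_path

    def load_node_config(file_path: str) -> dict:
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    path = save_node_config({"label": "SDXL", "params": {}}, {"prompt": "text_input"}, "out/node_config.json")
    assert load_node_config(path)["name_mapping"]["prompt"] == "text_input"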
@@ -34,11 +34,24 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
from .encoders import (
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLTextEncoderStep,
)
from .modular_block_mappings import (
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
CONTROLNET_UNION_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
SDXL_SUPPORTED_BLOCKS,
TEXT2IMAGE_BLOCKS,
)
from .modular_loader import StableDiffusionXLModularLoader
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
else:
import sys
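The `else: import sys` tail is the usual lazy-module fallback in diffusers `__init__` files. As a rough standalone sketch of the pattern — not the actual diffusers `_LazyModule` implementation — deferring submodule imports until attribute access looks like this:

    import importlib
    from types import ModuleType

    class LazyModule(ModuleType):
        # Resolve attributes to objects in submodules on first access, then cache.
        def __init__(self, name: str, import_structure: dict):
            super().__init__(name)
            self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

        def __getattr__(self, attr: str):
            module_name = self._attr_to_module.get(attr)
            if module_name is None:
                raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
            value = getattr(importlib.import_module(f"{self.__name__}.{module_name}"), attr)
            setattr(self, attr, value)  # cache for subsequent lookups
            return value

    # In a package __init__ this would replace the module object, e.g.:
    # sys.modules[__name__] = LazyModule(__name__, {"modular_loader": ["StableDiffusionXLModularLoader"]})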
@@ -13,32 +13,27 @@
# limitations under the License.

import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Optional, Tuple, Union

import PIL
import torch
from collections import OrderedDict

from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module

from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...schedulers import EulerDiscreteScheduler
from ...configuration_utils import FrozenDict

from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import (
AutoPipelineBlocks,
ModularLoader,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_loader import StableDiffusionXLModularLoader

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -237,7 +232,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."),
InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."),
]

@property
def intermediates_outputs(self) -> List[str]:
return [
@@ -250,7 +245,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"),
OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"),
]

def check_inputs(self, components, block_state):

if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
@@ -270,13 +265,13 @@ class StableDiffusionXLInputStep(PipelineBlock):
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list):
raise ValueError("`ip_adapter_embeds` must be a list")

if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list):
raise ValueError("`negative_ip_adapter_embeds` must be a list")

if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None:
for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape:
@@ -298,19 +293,19 @@ class StableDiffusionXLInputStep(PipelineBlock):
# duplicate text embeddings for each generation per prompt, using mps friendly method
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1)

if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1)

block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1)

if block_state.negative_pooled_prompt_embeds is not None:
block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1)

if block_state.ip_adapter_embeds is not None:
for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0)
@@ -318,7 +313,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
if block_state.negative_ip_adapter_embeds is not None:
for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds):
block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0)

self.add_block_state(state, block_state)

return components, state
@@ -356,14 +351,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"),
]

@property
def intermediates_outputs(self) -> List[str]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation")
]

@@ -455,7 +450,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]

@property
def description(self) -> str:
return (
@@ -473,7 +468,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):

@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")]
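`latent_timestep` above is derived from `strength` the same way the non-modular img2img pipelines do it; a rough standalone sketch of that selection (the actual `get_timesteps` body is not shown in these hunks, so this mirrors the standard logic under that assumption):

    def get_timesteps(timesteps: list, num_inference_steps: int, strength: float):
        # Skip the first (1 - strength) fraction of the schedule; the first
        # remaining timestep is the initial noise level for image-to-image.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        kept = timesteps[t_start:]
        latent_timestep = kept[0] if kept else None
        return kept, num_inference_steps - t_start, latent_timestep

    timesteps = list(range(1000, 0, -100))  # toy 10-step schedule
    kept, n, latent_t = get_timesteps(timesteps, 10, 0.3)
    assert n == 3 and latent_t == 300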
@@ -524,7 +519,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
InputParam(
"strength",
"strength",
default=0.9999,
description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` "
"will be used as a starting point, adding more noise to it the larger the `strength`. The number of "
@@ -540,46 +535,46 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."
),
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."
),
),
InputParam(
"mask",
required=True,
type_hint=torch.Tensor,
"mask",
required=True,
type_hint=torch.Tensor,
description="The mask for the inpainting generation. Can be generated in vae_encode step."
),
),
InputParam(
"masked_image_latents",
type_hint=torch.Tensor,
"masked_image_latents",
type_hint=torch.Tensor,
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step."
),
InputParam(
"dtype",
type_hint=torch.dtype,
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs"
)
]

@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"),
return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"),
OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")]

@@ -587,13 +582,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
# YiYi TODO: update the _encode_vae_image so that we can use # Copied from
@staticmethod
def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):

latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
@@ -619,7 +614,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
else:
image_latents = components.vae.config.scaling_factor * image_latents

return image_latents
return image_latents
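`_encode_vae_image` normalizes latents using the VAE's config; a condensed sketch of that scaling rule (the mean/std branch is implied by the `else:` visible in the hunk — when `latents_mean`/`latents_std` are absent, only `scaling_factor` applies; the 0.13025 constant is SDXL's published scaling factor):

    import torch

    def scale_vae_latents(image_latents, scaling_factor, latents_mean=None, latents_std=None):
        # VAEs that publish latent statistics normalize with mean/std;
        # SDXL-style VAEs just multiply by the scaling factor.
        if latents_mean is not None and latents_std is not None:
            return (image_latents - latents_mean) * scaling_factor / latents_std
        return scaling_factor * image_latents

    latents = torch.randn(1, 4, 64, 64)
    scaled = scale_vae_latents(latents, scaling_factor=0.13025)
    assert scaled.shape == latents.shape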
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
def prepare_latents_inpaint(
@@ -737,15 +732,15 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

return mask, masked_image_latents

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device

block_state.is_strength_max = block_state.strength == 1.0

# for non-inpainting specific unet, we do not need masked_image_latents
@@ -822,9 +817,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."),
InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."),
InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]

@property
@@ -886,14 +881,14 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
return [
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
),
InputParam(
"dtype",
type_hint=torch.dtype,
"dtype",
type_hint=torch.dtype,
description="The dtype of the model inputs"
)
]
@@ -902,8 +897,8 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents",
type_hint=torch.Tensor,
"latents",
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process"
)
]
@@ -980,7 +975,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):

model_name = "stable-diffusion-xl"

@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("requires_aesthetics_score", False),]
@@ -1008,15 +1003,15 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."),
InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
]

@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")]

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components
@@ -1183,29 +1178,29 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
]

@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"),
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")]

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components
@@ -1344,26 +1339,26 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
),
]
@@ -1395,12 +1390,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
device,
dtype,
crops_coords=None,
):
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
@@ -1416,9 +1411,9 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:

block_state = self.get_block_state(state)

# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
@@ -1446,14 +1441,14 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)

# (1.3)
# global_pool_conditions
# global_pool_conditions
block_state.global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
# (1.4)
# guess_mode
# guess_mode
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions

# (1.5)
@@ -1501,12 +1496,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
]
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

block_state.controlnet_cond = block_state.control_image
block_state.conditioning_scale = block_state.controlnet_conditioning_scale

self.add_block_state(state, block_state)

return components, state
@@ -1542,32 +1537,32 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of model tensor inputs. Can be generated in input step."
),
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
),
]
@@ -1599,12 +1594,12 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
device,
dtype,
crops_coords=None,
):
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
@@ -1618,7 +1613,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):

@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:

block_state = self.get_block_state(state)

controlnet = unwrap_module(components.controlnet)
@@ -1651,7 +1646,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
if len(block_state.control_image) != len(block_state.control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")

# control_type
# control_type
block_state.num_control_type = controlnet.config.num_control_type
block_state.control_type = [0 for _ in range(block_state.num_control_type)]
for control_idx in block_state.control_mode:
@@ -1676,7 +1671,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
crops_coords=block_state.crops_coords,
)
block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]

# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
@@ -1687,7 +1682,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
block_state.control_type_idx = block_state.control_mode
block_state.controlnet_cond = block_state.control_image
block_state.conditioning_scale = block_state.controlnet_conditioning_scale

self.add_block_state(state, block_state)

return components, state
@@ -1698,7 +1693,7 @@ class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
block_names = ["controlnet_union", "controlnet"]
block_trigger_inputs = ["control_mode", "control_image"]

@property
def description(self):
return "Controlnet Input step that prepare the controlnet input.\n" + \
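`AutoPipelineBlocks` pairs `block_classes` with `block_trigger_inputs` positionally, so the union step wins whenever `control_mode` is present. A toy sketch of that dispatch rule (names here are illustrative, not the actual base-class code):

    def select_block(blocks, trigger_inputs, provided: dict):
        # First block whose trigger input is present (and not None) wins.
        for block, trigger in zip(blocks, trigger_inputs):
            if provided.get(trigger) is not None:
                return block
        return None

    blocks = ["controlnet_union", "controlnet"]
    triggers = ["control_mode", "control_image"]
    assert select_block(blocks, triggers, {"control_mode": 1, "control_image": "img"}) == "controlnet_union"
    assert select_block(blocks, triggers, {"control_image": "img"}) == "controlnet"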
@@ -12,29 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch
import numpy as np
from collections import OrderedDict

from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging

from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...configuration_utils import FrozenDict

from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline import (
AutoPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -44,15 +40,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):

model_name = "stable-diffusion-xl"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]

@@ -160,10 +156,10 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]

@property
def intermediates_inputs(self) -> List[str]:
return [
File diff suppressed because it is too large
@@ -12,17 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor, PipelineImageInput
|
||||
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
|
||||
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -30,26 +35,10 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from ...schedulers import EulerDiscreteScheduler
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
|
||||
from ..modular_pipeline import AutoPipelineBlocks, PipelineBlock, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_loader import StableDiffusionXLModularLoader
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -71,7 +60,7 @@ def retrieve_latents(
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
@@ -79,7 +68,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
|
||||
" for more details"
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
@@ -87,8 +76,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
|
||||
ComponentSpec("unet", UNet2DConditionModel),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config"),
|
||||
]
|
||||
@@ -97,8 +86,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"ip_adapter_image",
|
||||
PipelineImageInput,
|
||||
"ip_adapter_image",
|
||||
PipelineImageInput,
|
||||
required=True,
|
||||
description="The image(s) to be used as ip adapter"
|
||||
)
|
||||
@@ -111,7 +100,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
|
||||
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
|
||||
]

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
    @staticmethod
    def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
@@ -137,7 +126,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds

    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
@@ -219,7 +208,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        return (
            "Text Encoder step that generates text_embeddings to guide the image generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
@@ -228,9 +217,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("tokenizer_2", CLIPTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config"),
        ]

@@ -546,7 +535,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):

    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
@@ -558,9 +547,9 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config"),
        ]

@@ -576,7 +565,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]

    @property
@@ -586,13 +575,13 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):

        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()

@@ -618,8 +607,8 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents
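
Note (illustration, not part of the diff): the hunk elides the branch above the final scaling. In the standard SDXL inpaint implementation this branch normalizes the latents with the VAE statistics when they are available; a hedged sketch, with variable names taken from the surrounding code:

    # assumed shape of the elided branch: standardize with latents_mean/std
    # when the VAE config provides them, otherwise apply scaling_factor only
    if latents_mean is not None and latents_std is not None:
        latents_mean = latents_mean.to(device=image.device, dtype=dtype)
        latents_std = latents_std.to(device=image.device, dtype=dtype)
        image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
    else:
        image_latents = components.vae.config.scaling_factor * image_latents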


    @torch.no_grad()
@@ -628,7 +617,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)

@@ -651,23 +640,23 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):

class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config"),
            ComponentSpec(
                "mask_processor",
                VaeImageProcessor,
                config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
                default_creation_method="from_config"),
        ]

    @property
    def description(self) -> str:
@@ -694,21 +683,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
                OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
                OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"),
                OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):

        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
@@ -734,7 +723,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
    # do not accept do_classifier_free_guidance
@@ -784,8 +773,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        return mask, masked_image_latents
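
Note (illustration, not part of the diff): the elided body of prepare_mask_latents presumably follows the usual SDXL inpaint recipe, resizing the mask to the latent resolution before moving it to the target device; a hedged sketch under that assumption:

    # assumed core of the elided hunk: downsample the mask to latent size;
    # components.vae_scale_factor is assumed to exist on the loader
    mask = torch.nn.functional.interpolate(
        mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
    )
    mask = mask.to(device=device, dtype=dtype)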


    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
@@ -801,7 +790,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
        else:
            block_state.crops_coords = None
            block_state.resize_mode = "default"

        block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
        block_state.image = block_state.image.to(dtype=torch.float32)

@@ -834,7 +823,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):

# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
# Encode
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
    block_names = ["inpaint", "img2img"]
    block_trigger_inputs = ["mask_image", "image"]
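
Note (illustration, not part of the diff): AutoPipelineBlocks picks a sub-block based on which trigger input the caller provides, pairing block_trigger_inputs with block_classes positionally. A hedged sketch of that dispatch, not the actual implementation:

    # assumed selection logic: the first trigger input present wins
    def _select_block(inputs: dict):
        if inputs.get("mask_image") is not None:
            return StableDiffusionXLInpaintVaeEncoderStep  # "inpaint"
        if inputs.get("image") is not None:
            return StableDiffusionXLVaeEncoderStep  # "img2img"
        return None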

@@ -13,44 +13,40 @@
# limitations under the License.

from ..modular_pipeline_utils import InsertableOrderedDict
from .before_denoise import (
    StableDiffusionXLAutoBeforeDenoiseStep,
    StableDiffusionXLControlNetInputStep,
    StableDiffusionXLControlNetUnionInputStep,
    StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
    StableDiffusionXLImg2ImgPrepareLatentsStep,
    StableDiffusionXLImg2ImgSetTimestepsStep,
    StableDiffusionXLInpaintPrepareLatentsStep,
    StableDiffusionXLInputStep,
    StableDiffusionXLPrepareAdditionalConditioningStep,
    StableDiffusionXLPrepareLatentsStep,
    StableDiffusionXLSetTimestepsStep,
)
from .decoders import StableDiffusionXLAutoDecodeStep, StableDiffusionXLDecodeStep, StableDiffusionXLInpaintDecodeStep
from .denoise import (
    StableDiffusionXLAutoDenoiseStep,
    StableDiffusionXLControlNetDenoiseStep,
    StableDiffusionXLDenoiseLoop,
    StableDiffusionXLInpaintDenoiseLoop,
)
from .encoders import (
    StableDiffusionXLAutoIPAdapterStep,
    StableDiffusionXLAutoVaeEncoderStep,
    StableDiffusionXLInpaintVaeEncoderStep,
    StableDiffusionXLIPAdapterStep,
    StableDiffusionXLTextEncoderStep,
    StableDiffusionXLVaeEncoderStep,
)


# YiYi notes: comment out for now, work on this later
# block mapping
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("input", StableDiffusionXLInputStep),

@@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ...image_processor import PipelineImageInput
from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularLoader
from ..modular_pipeline_utils import InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Tuple, Union, Dict

from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .denoise import StableDiffusionXLAutoDenoiseStep
from .encoders import (
    StableDiffusionXLAutoIPAdapterStep,
    StableDiffusionXLAutoVaeEncoderStep,
    StableDiffusionXLTextEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
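
Note (illustration, not part of the diff): these Auto* steps are presumably chained into the StableDiffusionXLAutoPipeline defined later in this file; a hedged sketch of that composition, with the class body and block_names assumed rather than shown in this hunk:

    # hypothetical composition; the real class body is not in this hunk
    class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
        block_classes = [
            StableDiffusionXLTextEncoderStep,
            StableDiffusionXLAutoIPAdapterStep,
            StableDiffusionXLAutoVaeEncoderStep,
            StableDiffusionXLAutoBeforeDenoiseStep,
            StableDiffusionXLAutoDenoiseStep,
            StableDiffusionXLAutoDecodeStep,
        ]
        block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"]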

@@ -15,12 +15,12 @@
"""Utilities to dynamically load objects from the Hub."""

import importlib
import inspect
import json
import os
import re
import shutil
import signal
import sys
import threading
from pathlib import Path
@@ -531,4 +531,4 @@ def get_class_from_dynamic_module(
        revision=revision,
        local_files_only=local_files_only,
    )
    return get_class_in_module(class_name, final_module)
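
Note (illustration, not part of the diff): a hedged usage sketch for get_class_from_dynamic_module. The repo id, file name, and class name below are made up, and the keyword names are assumed from context:

    # hypothetical call: load a custom block class from a Hub repository
    block_cls = get_class_from_dynamic_module(
        "someuser/custom-sdxl-blocks",   # assumed Hub repo hosting the module
        module_file="custom_blocks.py",
        class_name="MyCustomBlock",
    )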