
smoothed energy guidance

Aryan
2025-04-15 21:13:13 +02:00
parent ba579f4da9
commit 720783e508
10 changed files with 431 additions and 21 deletions

View File

@@ -137,6 +137,7 @@ else:
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
]
)
_import_structure["hooks"].extend(
@@ -145,6 +146,7 @@ else:
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"LayerSkipConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
@@ -729,12 +731,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
)
from .hooks import (
FasterCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache,
apply_pyramid_attention_broadcast,
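For illustration, a minimal usage sketch (not taken from this commit) of the newly exported names, assuming the lazy-import entries above expose them at the top-level `diffusers` namespace like the existing guiders:

from diffusers import SmoothedEnergyGuidance, SmoothedEnergyGuidanceConfig
# Layer indices follow the recommendation in the guider docstring for Stable Diffusion 3.5 Medium;
# an explicit list of SmoothedEnergyGuidanceConfig objects can be passed via seg_guidance_config instead.
guider = SmoothedEnergyGuidance(guidance_scale=7.5, seg_guidance_scale=2.8, seg_guidance_layers=[7, 8, 9])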

View File

@@ -23,5 +23,6 @@ if is_torch_available():
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance]

View File

@@ -55,10 +55,10 @@ class BaseGuidance:
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def force_disable(self):
def _force_disable(self):
self._enabled = False
def force_enable(self):
def _force_enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:

View File

@@ -24,8 +24,9 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg
class SkipLayerGuidance(BaseGuidance):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
https://huggingface.co/papers/2411.18664
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional

View File

@@ -0,0 +1,250 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
seg_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
seg_blur_sigma (`float`, defaults to `9999999.0`):
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
The threshold above which the blur is considered infinite.
seg_guidance_start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not
provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of
`SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
def __init__(
self,
guidance_scale: float = 7.5,
seg_guidance_scale: float = 2.8,
seg_blur_sigma: float = 9999999.0,
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
self.seg_blur_sigma = seg_blur_sigma
self.seg_blur_threshold_inf = seg_blur_threshold_inf
self.seg_guidance_start = seg_guidance_start
self.seg_guidance_stop = seg_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= seg_guidance_start < 1.0):
raise ValueError(
f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
)
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
raise ValueError(
f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}."
)
if seg_guidance_layers is None and seg_guidance_config is None:
raise ValueError(
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
)
if seg_guidance_layers is not None and seg_guidance_config is not None:
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
if seg_guidance_layers is not None:
if isinstance(seg_guidance_layers, int):
seg_guidance_layers = [seg_guidance_layers]
if not isinstance(seg_guidance_layers, list):
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
seg_guidance_config = [seg_guidance_config]
if not isinstance(seg_guidance_config, list):
raise ValueError(
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
)
self.seg_guidance_config = seg_guidance_config
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
if self._is_seg_enabled() and self.is_conditional and self._num_outputs_prepared > 0:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if arg is None or isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_seg_enabled():
# If we're predicting pred_cond and pred_cond_seg only, we need to set the key to pred_cond_seg
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_cond_seg"
self._preds[key] = pred
if key == "pred_cond_seg":
# If we are in SEG mode, we need to remove the hooks after inference
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_cond_seg
pred = pred + self.seg_guidance_scale * shift
elif not self._is_seg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_seg = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_seg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero
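For reference, a small sketch (illustrative tensors, not from this commit) of what forward() above computes when both CFG and SEG are active, using the diffusers-native formulation (`use_original_formulation=False`):

import torch

guidance_scale, seg_guidance_scale = 7.5, 2.8
pred_cond = torch.randn(1, 4, 64, 64)      # conditional prediction
pred_uncond = torch.randn(1, 4, 64, 64)    # unconditional prediction
pred_cond_seg = torch.randn(1, 4, 64, 64)  # conditional prediction with blurred attention queries
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + seg_guidance_scale * (pred_cond - pred_cond_seg)
# With guidance_rescale > 0.0, the result is additionally rescaled against pred_cond via rescale_noise_cfg.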

View File

@@ -8,3 +8,4 @@ if is_torch_available():
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig

View File

@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
@@ -30,3 +34,10 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None
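A tiny sketch (hypothetical module, not from this commit) of how the relocated helper behaves:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
_get_submodule_from_fqn(model, "0")         # returns the Linear submodule whose qualified name is "0"
_get_submodule_from_fqn(model, "blocks.0")  # returns None when no submodule matches the name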

View File

@@ -20,7 +20,7 @@ import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
@@ -66,12 +66,13 @@ class LayerSkipConfig:
def __post_init__(self):
if not (0 <= self.dropout <= 1):
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __init__(self) -> None:
super().__init__()
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
@@ -226,10 +227,3 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None

View File

@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
@dataclass
class SmoothedEnergyGuidanceConfig:
r"""
Configuration for applying smoothed energy guidance to the self-attention query projections of specified transformer blocks.
Args:
indices (`List[int]`):
The indices of the transformer blocks whose self-attention query projections are blurred.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`.
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`.
If `None`, `to_q` is used by default.
"""
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: List[str] = None
class SmoothedEnergyGuidanceHook(ModelHook):
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
super().__init__()
self.blur_sigma = blur_sigma
self.blur_threshold_inf = blur_threshold_inf
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
return smoothed_output
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
for identifier in config._query_proj_identifiers:
query_proj = getattr(submodule, identifier, None)
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
continue
logger.debug(
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
)
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
seq_len_sqrt = int(math.sqrt(seq_len))
num_square_tokens = seq_len_sqrt * seq_len_sqrt
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
# Infinite blur: collapse the queries to their spatial mean (uniform queries)
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
else:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
kernel1d = kernel1d.to(query)
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
query_slice = F.pad(query_slice, padding, mode="reflect")
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice
return query
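For reference, a minimal sketch (illustrative shapes, not from this commit) of calling _gaussian_blur_2d the way SmoothedEnergyGuidanceHook.post_forward does, on a query tensor whose tokens form a square grid:

import math
import torch

batch_size, seq_len, embed_dim = 2, 64 * 64, 8  # 4096 tokens -> 64x64 spatial grid
query = torch.randn(batch_size, seq_len, embed_dim)
blur_sigma = 3.0  # finite sigma -> Gaussian blur over the token grid
kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2  # same formula as post_forward
blurred = _gaussian_blur_2d(query, kernel_size, blur_sigma, sigma_threshold_inf=9999.0)
# With the default blur_sigma of 9999999.0 (> threshold), the queries instead collapse to their spatial mean,
# i.e. the "infinite blur" / uniform-query behaviour described in the guider docstring.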

View File

@@ -2231,9 +2231,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
data.num_channels_unet = pipeline.unet.config.in_channels
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
if data.disable_guidance:
pipeline.guider.force_disable()
pipeline.guider._force_disable()
else:
pipeline.guider.force_enable()
pipeline.guider._force_enable()
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -2626,9 +2626,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# (2) Prepare conditional inputs for unet using the guider
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
if data.disable_guidance:
pipeline.guider.force_disable()
pipeline.guider._force_disable()
else:
pipeline.guider.force_enable()
pipeline.guider._force_enable()
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta)
@@ -3039,9 +3039,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
if data.disable_guidance:
pipeline.guider.force_disable()
pipeline.guider._force_disable()
else:
pipeline.guider.force_enable()
pipeline.guider._force_enable()
data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype)
repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]