1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add note about seg

This commit is contained in:
Aryan
2025-04-15 21:25:15 +02:00
parent 720783e508
commit b9bcd469f1
2 changed files with 15 additions and 2 deletions

View File

@@ -25,6 +25,9 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility.
Args:
guidance_scale (`float`, defaults to `7.5`):

View File

@@ -113,6 +113,16 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
"""
This implementation assumes that the input query contains visual (image/video) tokens, to which the 2D Gaussian
blur is applied. However, some models use joint text-visual token attention, for which this may not be suitable.
Additionally, this implementation assumes that the visual tokens come from a square image/video. In practice,
despite these assumptions, applying the 2D square Gaussian blur to the query projections generates reasonable
results for Smoothed Energy Guidance.
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
@@ -139,10 +149,10 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
query_slice = F.pad(query_slice, padding, mode="reflect")
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
else:
query[:] = query.mean(dim=(-2, -1), keepdim=True)
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice
query[:, :num_square_tokens, :] = query_slice.clone()
return query