mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add note about seg
This commit is contained in:
@@ -25,6 +25,9 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
class SmoothedEnergyGuidance(BaseGuidance):
|
||||
"""
|
||||
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
|
||||
in the future without warning or guarantee of reproducibility.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
|
||||
@@ -113,6 +113,16 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth
|
||||
|
||||
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
|
||||
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
|
||||
"""
|
||||
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian
|
||||
blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally,
|
||||
this implementation also assumes that the visual tokens come from a square image/video. In practice, despite
|
||||
these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results
|
||||
for Smoothed Energy Guidance.
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
|
||||
in the future without warning or guarantee of reproducibility.
|
||||
"""
|
||||
assert query.ndim == 3
|
||||
|
||||
is_inf = sigma > sigma_threshold_inf
|
||||
@@ -139,10 +149,10 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma
|
||||
query_slice = F.pad(query_slice, padding, mode="reflect")
|
||||
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
|
||||
else:
|
||||
query[:] = query.mean(dim=(-2, -1), keepdim=True)
|
||||
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
|
||||
|
||||
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
|
||||
query_slice = query_slice.permute(0, 2, 1)
|
||||
query[:, :num_square_tokens, :] = query_slice
|
||||
query[:, :num_square_tokens, :] = query_slice.clone()
|
||||
|
||||
return query
|
||||
|
||||
Reference in New Issue
Block a user