diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index bd2a61b894..2328aa82ec 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -25,6 +25,9 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg class SmoothedEnergyGuidance(BaseGuidance): """ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. Args: guidance_scale (`float`, defaults to `7.5`): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index 20df0de048..f0366e2988 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -113,6 +113,16 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth # Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query holds visual (image/video) tokens, to which the 2D Gaussian + blur is applied. However, some models use joint text-visual token attention, for which this may not be suitable. Additionally, + this implementation assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square Gaussian blur to the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. 
+ """ assert query.ndim == 3 is_inf = sigma > sigma_threshold_inf @@ -139,10 +149,10 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = F.pad(query_slice, padding, mode="reflect") query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) else: - query[:] = query.mean(dim=(-2, -1), keepdim=True) + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.permute(0, 2, 1) - query[:, :num_square_tokens, :] = query_slice + query[:, :num_square_tokens, :] = query_slice.clone() return query