mirror of https://github.com/huggingface/diffusers.git
Support non square image generation for StableDiffusionSAGPipeline (#2629)
* Support non square image generation for StableDiffusionSAGPipeline
* Fix style
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -606,64 +605,73 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+        map_size = None
 
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]
 
-                # perform self-attention guidance with the stored self-attentnion map
-                if do_self_attention_guidance:
-                    # classifier-free guidance produces two chunks of attention map
-                    # and we only use unconditional one according to equation (24)
-                    # in https://arxiv.org/pdf/2210.00939.pdf
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                    ).sample
+
+                    # perform guidance
                     if do_classifier_free_guidance:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
-                        # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
-                        )
-                        uncond_emb, _ = prompt_embeds.chunk(2)
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
-                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
-                    else:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred, t)
-                        # get the stored attention maps
-                        cond_attn = store_processor.attention_probs
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
-                        )
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
-                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                    # perform self-attention guidance with the stored self-attentnion map
+                    if do_self_attention_guidance:
+                        # classifier-free guidance produces two chunks of attention map
+                        # and we only use unconditional one according to equation (24)
+                        # in https://arxiv.org/pdf/2210.00939.pdf
+                        if do_classifier_free_guidance:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+                            # get the stored attention maps
+                            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            )
+                            uncond_emb, _ = prompt_embeds.chunk(2)
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                        else:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred, t)
+                            # get the stored attention maps
+                            cond_attn = store_processor.attention_probs
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+                            )
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                            noise_pred += sag_scale * (noise_pred - degraded_pred)
 
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, latents)
 
         # 8. Post-processing
         image = self.decode_latents(latents)
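The heart of the hunk above is that the pipeline no longer assumes a square self-attention map: the spatial size of the mid-block feature map is recorded at runtime with a forward hook and passed to sag_masking as map_size. Below is a minimal, self-contained sketch of that hook pattern (illustration only, not part of the diff); DummyMidBlock and its .sample field are assumed stand-ins for the UNet mid block.

from types import SimpleNamespace

import torch
import torch.nn as nn


class DummyMidBlock(nn.Module):
    def forward(self, hidden_states):
        # the real UNet mid block returns an object exposing `.sample`
        return SimpleNamespace(sample=hidden_states)


map_size = None


def get_map_size(module, input, output):
    global map_size
    map_size = output.sample.shape[-2:]  # (height, width) of the feature map


block = DummyMidBlock()
# A 512x768 image gives a 64x96 latent and an 8x12 mid-block feature map,
# so a square assumption such as math.isqrt(8 * 12) == 9 cannot recover (8, 12).
with block.register_forward_hook(get_map_size):
    block(torch.randn(1, 1280, 8, 12))
print(map_size)  # torch.Size([8, 12])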
@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)
 
         # Produce attention mask
        attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
        )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
 
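For context on the sag_masking change: the stored self-attention map over an h x w feature map has hw1 == h * w entries per query, and math.isqrt(hw1) only recovers the side length when the map is square. A small worked check with assumed values for a 512x768 generation (illustrative, not code from the repository):

import math

import torch

h, w = 8, 12                  # assumed mid-block feature-map size for a 512x768 image
hw1 = h * w                   # 96 attention positions per query
print(math.isqrt(hw1))        # 9, and 9 * 9 != 96 -> the old square reshape breaks
attn_mask = torch.rand(2, hw1) > 0.5
print(attn_mask.reshape(2, h, w).shape)  # torch.Size([2, 8, 12]) with the hooked (h, w)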
@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+
+    def test_stable_diffusion_2_non_square(self):
+        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sag_pipe = sag_pipe.to(torch_device)
+        sag_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "."
+        generator = torch.manual_seed(0)
+        output = sag_pipe(
+            [prompt],
+            width=768,
+            height=512,
+            generator=generator,
+            guidance_scale=7.5,
+            sag_scale=1.0,
+            num_inference_steps=20,
+            output_type="np",
+        )
+
+        image = output.images
+
+        assert image.shape == (1, 512, 768, 3)
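A usage sketch mirroring the new integration test: after this change, the width and height passed to the SAG pipeline no longer have to match. The prompt, output filename, and CUDA device below are illustrative choices, not part of the commit; the checkpoint id is the one used in the test.

import torch

from diffusers import StableDiffusionSAGPipeline

pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    width=768,                 # non-square: width != height
    height=512,
    guidance_scale=7.5,
    sag_scale=0.75,
    num_inference_steps=25,
    generator=torch.manual_seed(0),
).images[0]
image.save("sag_768x512.png")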