diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 0a34499e09..afb4731055 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
@@ -606,64 +605,73 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):

         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+        map_size = None

-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]

-                # perform self-attention guidance with the stored self-attentnion map
-                if do_self_attention_guidance:
-                    # classifier-free guidance produces two chunks of attention map
-                    # and we only use unconditional one according to equation (24)
-                    # in https://arxiv.org/pdf/2210.00939.pdf
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                    ).sample
+
+                    # perform guidance
                     if do_classifier_free_guidance:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
-                        # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
-                        )
-                        uncond_emb, _ = prompt_embeds.chunk(2)
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
-                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
-                    else:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred, t)
-                        # get the stored attention maps
-                        cond_attn = store_processor.attention_probs
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
-                        )
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
-                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                    # perform self-attention guidance with the stored self-attention map
+                    if do_self_attention_guidance:
+                        # classifier-free guidance produces two chunks of attention map
+                        # and we only use unconditional one according to equation (24)
+                        # in https://arxiv.org/pdf/2210.00939.pdf
+                        if do_classifier_free_guidance:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+                            # get the stored attention maps
+                            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            )
+                            uncond_emb, _ = prompt_embeds.chunk(2)
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                        else:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred, t)
+                            # get the stored attention maps
+                            cond_attn = store_processor.attention_probs
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+                            )
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                            noise_pred += sag_scale * (noise_pred - degraded_pred)

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)
@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)

         # Produce attention mask
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))

diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
index ec2bb2185e..abaefbcad0 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+
+    def test_stable_diffusion_2_non_square(self):
+        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sag_pipe = sag_pipe.to(torch_device)
+        sag_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "."
+        generator = torch.manual_seed(0)
+        output = sag_pipe(
+            [prompt],
+            width=768,
+            height=512,
+            generator=generator,
+            guidance_scale=7.5,
+            sag_scale=1.0,
+            num_inference_steps=20,
+            output_type="np",
+        )
+
+        image = output.images
+
+        assert image.shape == (1, 512, 768, 3)
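The core of the change is replacing the square-map assumption (`math.isqrt(hw1)`) with a forward hook that records the actual spatial size of the mid-block output. Below is a minimal, self-contained sketch of that hook pattern; `TinyBlock` is a hypothetical stand-in for `self.unet.mid_block.attentions[0]`, whose real output wraps the tensor in a `.sample` attribute.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for the UNet mid-block attention; the real module's
# output carries the tensor in `.sample`, hence `output.sample.shape[-2:]`
# in the pipeline.
class TinyBlock(nn.Module):
    def forward(self, x):
        return x


map_size = None


def get_map_size(module, input, output):
    global map_size  # the pipeline uses `nonlocal` inside __call__
    map_size = output.shape[-2:]  # (height, width), not necessarily square


block = TinyBlock()
# RemovableHandle is a context manager, so the hook is removed on exit
with block.register_forward_hook(get_map_size):
    block(torch.randn(1, 4, 8, 12))  # non-square feature map

print(map_size)  # torch.Size([8, 12])
```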
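The loop's comments reference a "DDIM-like prediction of x0". For an epsilon-predicting model this is the standard estimate below; a hedged sketch only, since the pipeline's own `pred_x0`/`pred_epsilon` helpers are the authoritative versions and may cover additional prediction types.

```python
import torch


# Sketch of the "DDIM-like prediction of x0", assuming an epsilon-predicting
# model: x0 = (x_t - sqrt(1 - alpha_prod_t) * eps) / sqrt(alpha_prod_t)
def pred_x0_sketch(sample: torch.Tensor, eps: torch.Tensor, alpha_prod_t: float) -> torch.Tensor:
    beta_prod_t = 1 - alpha_prod_t
    return (sample - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5
```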
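In `sag_masking`, the reshape now consumes the hooked `(height, width)` pair instead of `math.isqrt(hw1)`, which silently assumed a square attention map and fails for non-square generations: at 512x768 the mid-block map is 8x12 = 96 tokens, and `isqrt(96) = 9` cannot tile them. A standalone sketch with illustrative shapes (not taken from a real UNet):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes for a hypothetical 512x768 generation: latents are
# 64x96 and the mid-block attention map is 8x12 (96 tokens), so the old
# square reshape to (9, 9) would raise a RuntimeError.
b, latent_channel, latent_h, latent_w = 1, 4, 64, 96
map_size = (8, 12)  # captured by the forward hook

attn_mask = torch.rand(b, map_size[0] * map_size[1]) > 0.5  # per-token mask

attn_mask = (
    attn_mask.reshape(b, map_size[0], map_size[1])
    .unsqueeze(1)
    .repeat(1, latent_channel, 1, 1)
    .float()
)
attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))  # upsample to latent size
print(attn_mask.shape)  # torch.Size([1, 4, 64, 96])
```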
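Finally, a usage sketch mirroring the new integration test: with the hook-based `map_size`, non-square resolutions run through `StableDiffusionSAGPipeline` without the reshape error. The prompt and device here are illustrative, not part of the change.

```python
import torch

from diffusers import StableDiffusionSAGPipeline

pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")  # assumes a CUDA device; CPU also works, just slowly

image = pipe(
    "a photo of an astronaut riding a horse",  # illustrative prompt
    width=768,
    height=512,
    guidance_scale=7.5,
    sag_scale=1.0,
    num_inference_steps=20,
    generator=torch.manual_seed(0),
).images[0]
```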