From ced7c9601affad9fb24f7dce3311a424bbd6cdb8 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 7 Dec 2022 15:14:34 +0100 Subject: [PATCH] fix upcast in slice attention (#1591) * fix upcast in slice attention * fix dtype * add test * fix test --- src/diffusers/models/attention.py | 6 +++--- .../test_stable_diffusion_v_pred.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f09172be41..8b855a5ed5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -649,9 +649,9 @@ class CrossAttention(nn.Module): key_slice = key_slice.float() attn_slice = torch.baddbmm( - torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query[start_idx:end_idx], - key[start_idx:end_idx].transpose(-1, -2), + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), beta=0, alpha=self.scale, ) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index 5fe973b2ac..9755bf3b18 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -265,6 +265,25 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_v_pred_upcast_attention(self): + sd_pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np") + + image = output.images + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0461, 0.0483, 0.0566, 0.0512, 0.0446, 0.0751, 0.0664, 0.0551, 0.0488]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_v_pred_euler(self): scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler") sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)