1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix upcast in slice attention (#1591)

* fix upcast in slice attention

* fix dtype

* add test

* fix test
This commit is contained in:
Suraj Patil
2022-12-07 15:14:34 +01:00
committed by GitHub
parent 8e74efad01
commit ced7c9601a
2 changed files with 22 additions and 3 deletions

View File

@@ -649,9 +649,9 @@ class CrossAttention(nn.Module):
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query[start_idx:end_idx],
key[start_idx:end_idx].transpose(-1, -2),
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)

View File

@@ -265,6 +265,25 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_v_pred_upcast_attention(self):
sd_pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.enable_attention_slicing()
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
image = output.images
image_slice = image[0, 253:256, 253:256, -1]
assert image.shape == (1, 768, 768, 3)
expected_slice = np.array([0.0461, 0.0483, 0.0566, 0.0512, 0.0446, 0.0751, 0.0664, 0.0551, 0.0488])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_v_pred_euler(self):
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)