mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix upcast in slice attention (#1591)
* fix upcast in slice attention * fix dtype * add test * fix test
This commit is contained in:
@@ -649,9 +649,9 @@ class CrossAttention(nn.Module):
|
||||
key_slice = key_slice.float()
|
||||
|
||||
attn_slice = torch.baddbmm(
|
||||
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx].transpose(-1, -2),
|
||||
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
||||
query_slice,
|
||||
key_slice.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
@@ -265,6 +265,25 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_v_pred_upcast_attention(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 768, 768, 3)
|
||||
expected_slice = np.array([0.0461, 0.0483, 0.0566, 0.0512, 0.0446, 0.0751, 0.0664, 0.0551, 0.0488])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_v_pred_euler(self):
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)
|
||||
|
||||
Reference in New Issue
Block a user