diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index a51a12ee28..6b56403a6f 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -38,6 +38,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ scipy \ tensorboard \ transformers \ - omegaconf + omegaconf \ + pytorch-lightning \ + xformers CMD ["/bin/bash"] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0b86dbe546..1bfaa02581 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -540,9 +540,14 @@ class LoRAAttnProcessor(nn.Module): self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim if input_ndim == 4: @@ -905,9 +910,13 @@ class XFormersAttnProcessor: hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, ): residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim if input_ndim == 4: @@ -1081,9 +1090,14 @@ class LoRAXFormersAttnProcessor(nn.Module): self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim if input_ndim == 4: @@ -1334,8 +1348,12 @@ class SlicedAttnAddedKVProcessor: def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7d5e6bcace..abddd48851 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -577,3 +577,9 @@ def enable_full_determinism(): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False + + +def disable_full_determinism(): + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" + torch.use_deterministic_algorithms(False) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 6140bf771e..b5d968e2a3 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -37,16 +37,18 @@ from diffusers import ( UNet2DConditionModel, logging, ) -from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import ( CaptureLogger, + disable_full_determinism, enable_full_determinism, require_torch_2, require_torch_gpu, run_test_in_subprocess, ) +from ...models.test_lora_layers import create_unet_lora_layers from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -366,6 +368,56 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda") + def test_stable_diffusion_attn_processors(self): + disable_full_determinism() + device = "cuda" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + + # run normal sd pipe + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run xformers attention + sd_pipe.enable_xformers_memory_efficient_attention() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run attention slicing + sd_pipe.enable_attention_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run vae attention slicing + sd_pipe.enable_vae_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora xformers attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = { + k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim) + for k, v in attn_processors.items() + } + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + enable_full_determinism() + def test_stable_diffusion_no_safety_checker(self): pipe = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None