
Fix temb attention (#3607)

* Fix temb attention

* Apply suggestions from code review

* make style

* Add tests and fix docker

* Apply suggestions from code review
Patrick von Platen
2023-05-30 11:26:23 +01:00
committed by GitHub
parent c6ae883751
commit c0f867afd1
4 changed files with 83 additions and 5 deletions


@@ -38,6 +38,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
     scipy \
     tensorboard \
     transformers \
-    omegaconf
+    omegaconf \
+    pytorch-lightning \
+    xformers

 CMD ["/bin/bash"]
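The xformers package added here is what the new attention-processor test further down exercises via enable_xformers_memory_efficient_attention(). After rebuilding the image, one might confirm that the freshly added packages import cleanly; the snippet below is only an illustrative smoke check, not part of this commit.

# Illustrative smoke check for the rebuilt image (not part of this commit).
import importlib

for pkg in ("pytorch_lightning", "xformers"):
    mod = importlib.import_module(pkg)  # raises ImportError if the image is missing the package
    print(pkg, getattr(mod, "__version__", "unknown"))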


@@ -540,9 +540,14 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -905,9 +910,13 @@ class XFormersAttnProcessor:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
     ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -1081,9 +1090,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -1334,8 +1348,12 @@ class SlicedAttnAddedKVProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size

-    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
         residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

         batch_size, sequence_length, _ = hidden_states.shape
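Every processor hunk above applies the same pattern: __call__ gains an optional temb argument, and when the attention block carries a spatial_norm module, that norm is applied to hidden_states (conditioned on temb) before the usual projections. The sketch below illustrates roughly what such an embedding-conditioned spatial norm does; SimpleSpatialNorm is an illustrative stand-in, not diffusers' actual SpatialNorm implementation.

# Rough stand-in for an embedding-conditioned spatial norm (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSpatialNorm(nn.Module):
    def __init__(self, channels, temb_channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        self.conv_scale = nn.Conv2d(temb_channels, channels, kernel_size=1)
        self.conv_shift = nn.Conv2d(temb_channels, channels, kernel_size=1)

    def forward(self, hidden_states, temb):
        # Resize the embedding map to the feature map's spatial size, then
        # modulate the normalized features with projections of the embedding.
        temb = F.interpolate(temb, size=hidden_states.shape[-2:], mode="nearest")
        return self.norm(hidden_states) * self.conv_scale(temb) + self.conv_shift(temb)

features = torch.randn(1, 64, 16, 16)  # attention-block feature map
temb = torch.randn(1, 4, 8, 8)         # spatial conditioning embedding
out = SimpleSpatialNorm(channels=64, temb_channels=4)(features, temb)
print(out.shape)  # torch.Size([1, 64, 16, 16])

Before this change these processors neither accepted temb nor consulted attn.spatial_norm, so swapping them into a model whose attention blocks use such a norm would either fail or silently drop the conditioning.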


@@ -577,3 +577,9 @@ def enable_full_determinism():
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
     torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
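disable_full_determinism() is the counterpart to enable_full_determinism(): it relaxes the global determinism settings so that non-deterministic kernels (such as xformers memory-efficient attention) can run. The new pipeline test below brackets its runs with the two helpers; a minimal usage sketch of that pattern (the test body here is hypothetical) could look like:

from diffusers.utils.testing_utils import disable_full_determinism, enable_full_determinism

def test_with_nondeterministic_kernels():  # hypothetical test
    disable_full_determinism()  # allow non-deterministic ops for this test only
    try:
        ...  # e.g. run a pipeline with xformers attention enabled
    finally:
        enable_full_determinism()  # restore strict determinism for the rest of the suite

Wrapping the restore call in finally (a small deviation from the test below) guards against a failure leaving the rest of the suite in a non-deterministic state.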


@@ -37,16 +37,18 @@ from diffusers import (
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    disable_full_determinism,
     enable_full_determinism,
     require_torch_2,
     require_torch_gpu,
     run_test_in_subprocess,
 )

+from ...models.test_lora_layers import create_unet_lora_layers
 from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -366,6 +368,56 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
+    def test_stable_diffusion_attn_processors(self):
+        disable_full_determinism()
+        device = "cuda"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # run normal sd pipe
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run xformers attention
+        sd_pipe.enable_xformers_memory_efficient_attention()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run attention slicing
+        sd_pipe.enable_attention_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run vae attention slicing
+        sd_pipe.enable_vae_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora xformers attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {
+            k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
+            for k, v in attn_processors.items()
+        }
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        enable_full_determinism()
+
     def test_stable_diffusion_no_safety_checker(self):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None