Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00

Fix temb attention (#3607)

* Fix temb attention
* Apply suggestions from code review
* make style
* Add tests and fix docker
* Apply suggestions from code review

committed by GitHub
parent c6ae883751
commit c0f867afd1
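
In short, the fix threads an optional `temb` argument (the time-embedding / spatial conditioning tensor) through the attention processors that were still missing it: `LoRAAttnProcessor`, `XFormersAttnProcessor`, `LoRAXFormersAttnProcessor`, and `SlicedAttnAddedKVProcessor`. With the argument in place, each processor can apply `attn.spatial_norm(hidden_states, temb)` when the attention block carries a spatial norm. A minimal, illustrative stand-in for the pattern (names are placeholders, not the actual diffusers code):

```python
class SketchAttnProcessor:
    """Illustrative stand-in for the guard this commit adds to several processors."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # New: an optional `temb` keyword, consumed by the block's spatial norm when one is set.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # ... the existing projection / attention / reshape logic continues unchanged ...
        return hidden_states
```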
@@ -38,6 +38,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
         scipy \
         tensorboard \
         transformers \
-        omegaconf
+        omegaconf \
+        pytorch-lightning \
+        xformers
 
 CMD ["/bin/bash"]
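
The Dockerfile hunk above extends the image's pip install list: `omegaconf` gains a line-continuation backslash, and `pytorch-lightning` and `xformers` are appended, presumably so the new xformers attention tests can run in CI. As an illustrative smoke test (not part of the commit), the added packages can be checked inside the built image:

```python
# Run inside the container to confirm the newly added packages resolve.
import omegaconf
import pytorch_lightning
import xformers

print("omegaconf", omegaconf.__version__)
print("pytorch_lightning", pytorch_lightning.__version__)
print("xformers", xformers.__version__)
```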
@@ -540,9 +540,14 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states
 
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim
 
         if input_ndim == 4:
@@ -905,9 +910,13 @@ class XFormersAttnProcessor:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
     ):
         residual = hidden_states
 
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim
 
         if input_ndim == 4:
@@ -1081,9 +1090,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states
 
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim
 
         if input_ndim == 4:
@@ -1334,8 +1348,12 @@ class SlicedAttnAddedKVProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
-    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
         residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
 
         batch_size, sequence_length, _ = hidden_states.shape
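
The hunks above all touch `diffusers.models.attention_processor`. For context, `attn.spatial_norm` is only set on attention blocks whose normalization is conditioned on an extra tensor, which is what the forwarded `temb` feeds. The module's implementation is not part of this diff; the sketch below is an assumption about its shape, following the common spatially-conditioned GroupNorm pattern (normalize the feature map, then scale and shift it with 1x1-conv projections of the conditioning tensor):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialNormSketch(nn.Module):
    """Assumed shape of a spatial_norm module; illustrative, not the diffusers implementation."""

    def __init__(self, f_channels: int, zq_channels: int):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=f_channels, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # Resize the conditioning tensor to the feature map's spatial size,
        # then modulate the normalized features with it.
        zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        return self.norm_layer(f) * self.conv_y(zq) + self.conv_b(zq)
```

With `f_channels=64` and `zq_channels=4`, calling the module on a `(1, 64, 8, 8)` feature map and a `(1, 4, 32, 32)` conditioning tensor returns a `(1, 64, 8, 8)` tensor, matching the input features.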
@@ -577,3 +577,9 @@ def enable_full_determinism():
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
     torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
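
The new `disable_full_determinism` helper above is the counterpart to `enable_full_determinism`: it lets a test temporarily allow nondeterministic kernels (xformers' memory-efficient attention is one) and restore strict determinism afterwards, which is how the test added below uses it. A usage sketch, with a `try/finally` added for safety:

```python
from diffusers.utils.testing_utils import disable_full_determinism, enable_full_determinism

disable_full_determinism()
try:
    # ... exercise code paths that rely on nondeterministic kernels,
    # e.g. xformers memory-efficient attention ...
    pass
finally:
    enable_full_determinism()
```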
@@ -37,16 +37,18 @@ from diffusers import (
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    disable_full_determinism,
     enable_full_determinism,
     require_torch_2,
     require_torch_gpu,
     run_test_in_subprocess,
 )
 
+from ...models.test_lora_layers import create_unet_lora_layers
 from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -366,6 +368,56 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
+    def test_stable_diffusion_attn_processors(self):
+        disable_full_determinism()
+        device = "cuda"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # run normal sd pipe
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run xformers attention
+        sd_pipe.enable_xformers_memory_efficient_attention()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run attention slicing
+        sd_pipe.enable_attention_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run vae attention slicing
+        sd_pipe.enable_vae_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora xformers attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {
+            k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
+            for k, v in attn_processors.items()
+        }
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        enable_full_determinism()
+
     def test_stable_diffusion_no_safety_checker(self):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
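
Finally, on how `temb` actually reaches these processors: `Attention.forward` forwards extra keyword arguments on to whichever processor is set. The sketch below assumes the `Attention` constructor exposes a `spatial_norm_dim` argument and passes kwargs through as diffusers did around this release; the shapes are made up for illustration:

```python
import torch
from diffusers.models.attention_processor import Attention

# Assumed constructor arguments: spatial_norm_dim attaches a spatial_norm module,
# so the temb branch exercised by this commit is actually taken.
attn = Attention(query_dim=32, heads=2, dim_head=16, spatial_norm_dim=8)

hidden_states = torch.randn(1, 32, 8, 8)  # (batch, channels, height, width)
temb = torch.randn(1, 8, 8, 8)            # conditioning tensor consumed by spatial_norm

# Extra kwargs such as temb are forwarded to the processor's __call__.
out = attn(hidden_states, temb=temb)
print(out.shape)  # expected: torch.Size([1, 32, 8, 8])
```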