
Re-add xformers enable to UNet2DCondition (#1627)

* finish

* fix

* Update tests/models/test_models_unet_2d.py

* style

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Patrick von Platen
2022-12-09 14:05:38 +01:00
committed by anton-
parent 3faf204c49
commit dbfafb66f0
2 changed files with 48 additions and 0 deletions

tests/models/test_models_unet_2d.py

@@ -30,6 +30,7 @@ from diffusers.utils import (
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.import_utils import is_xformers_available
 from parameterized import parameterized
 
 from ..test_modeling_common import ModelTesterMixin
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+        ), "xformers is not enabled"
+
     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
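
For context, a minimal usage sketch of the method the new test exercises. The checkpoint ID below is illustrative, not part of this commit; any UNet2DConditionModel checkpoint works, and a CUDA device with xformers installed is assumed.

from diffusers import UNet2DConditionModel

# Load a pretrained UNet (checkpoint ID is illustrative) and move it to CUDA,
# since xformers' memory-efficient attention requires a CUDA device.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.to("cuda")

# The method re-added by this commit: switches the model's attention blocks
# to xformers' memory-efficient implementation.
unet.enable_xformers_memory_efficient_attention()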