mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Re-add xformers enable to UNet2DCondition (#1627)
* finish * fix * Update tests/models/test_models_unet_2d.py * style Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
committed by
anton-
parent
3faf204c49
commit
dbfafb66f0
@@ -30,6 +30,7 @@ from diffusers.utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
|
||||
Reference in New Issue
Block a user