From dbfafb66f064fa3274861ba85497efeaa247e454 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 9 Dec 2022 14:05:38 +0100
Subject: [PATCH] Re-add xformers enable to UNet2DCondition (#1627)

* finish

* fix

* Update tests/models/test_models_unet_2d.py

* style

Co-authored-by: Anton Lozhkov
---
 src/diffusers/modeling_utils.py     | 33 +++++++++++++++++++++++++++++
 tests/models/test_models_unet_2d.py | 15 +++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 280dca0005..520b3e9311 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -188,6 +188,39 @@ class ModelMixin(torch.nn.Module):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py
index 4a2d5a96ed..9071495b58 100644
--- a/tests/models/test_models_unet_2d.py
+++ b/tests/models/test_models_unet_2d.py
@@ -30,6 +30,7 @@ from diffusers.utils import (
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.import_utils import is_xformers_available
 from parameterized import parameterized
 
 from ..test_modeling_common import ModelTesterMixin
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+        ), "xformers is not enabled"
+
     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
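
Usage sketch: the two toggles this patch adds to `ModelMixin` can be exercised end to end as below, assuming `torch`, `diffusers`, and `xformers` are installed and a CUDA device is available. The checkpoint id and dtype are illustrative choices for the example, not anything the patch prescribes.

    import torch
    from diffusers import UNet2DConditionModel

    # Load any UNet2DConditionModel; this checkpoint is just an example.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    ).to("cuda")

    # New with this patch: recursively propagate the xformers flag to every
    # attention submodule that exposes set_use_memory_efficient_attention_xformers.
    unet.enable_xformers_memory_efficient_attention()

    # ... run inference; attention now uses xformers' memory-efficient kernels ...

    # Revert to the default attention implementation.
    unet.disable_xformers_memory_efficient_attention()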