mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Re-add xformers enable to UNet2DCondition (#1627)
* finish * fix * Update tests/models/test_models_unet_2d.py * style Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
ff65c2d72b
commit
cd91fc06fe
@@ -188,6 +188,39 @@ class ModelMixin(torch.nn.Module):
|
||||
if self._supports_gradient_checkpointing:
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
for module in self.children():
|
||||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
||||
@@ -30,6 +30,7 @@ from diffusers.utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
|
||||
Reference in New Issue
Block a user