Mirror of https://github.com/huggingface/diffusers.git
[Tests] clean up and refactor gradient checkpointing tests (#9494)
* check.
* fixes
* fixes
* updates
* fixes
* fixes
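The refactor replaces the per-model copies of test_gradient_checkpointing with two shared ModelTesterMixin tests, test_effective_gradient_checkpointing and test_gradient_checkpointing_is_applied. A minimal sketch of how a model test class hooks into them after this change; MyModelTests and MyModel are hypothetical placeholders, while the expected_set and tolerance values mirror the hunks below:

import unittest

# ModelTesterMixin is the shared tester class extended in the diff below;
# MyModel stands in for whatever diffusers model the test class covers.
class MyModelTests(ModelTesterMixin, unittest.TestCase):  # hypothetical subclass
    model_class = MyModel  # hypothetical model under test

    def test_gradient_checkpointing_is_applied(self):
        # class names of the submodules whose gradient_checkpointing flag
        # enable_gradient_checkpointing() is expected to flip
        expected_set = {"Decoder", "Encoder"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_effective_gradient_checkpointing(self):
        # override only when the default 1e-5 loss tolerance is too tight
        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)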
@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
    load_hf_numpy,
    require_torch_accelerator,
    require_torch_accelerator_with_fp16,
    require_torch_accelerator_with_training,
    require_torch_gpu,
    skip_mps,
    slow,
@@ -170,52 +169,17 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Not tested.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Not tested.")
    def test_training(self):
        pass

    @require_torch_accelerator_with_training
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Decoder", "Encoder"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Not tested.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Not tested.")
    def test_forward_with_norm_groups(self):
        pass

@@ -364,9 +330,20 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Not tested.")
    def test_outputs_equivalence(self):
        pass

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DecoderTiny", "EncoderTiny"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip(
        "Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
    )
    def test_effective_gradient_checkpointing(self):
        pass


class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
    model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase)
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Not tested.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Not tested.")
    def test_training(self):
        pass

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            if "post_quant_conv" in name:
                continue

            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Encoder", "TemporalDecoder"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

@@ -522,9 +461,11 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Not tested.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Not tested.")
    def test_forward_with_norm_groups(self):
        pass
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import json
import os

@@ -57,6 +58,7 @@ from diffusers.utils.testing_utils import (
    require_torch_gpu,
    require_torch_multi_gpu,
    run_test_in_subprocess,
    torch_all_close,
    torch_device,
)
@@ -785,6 +787,101 @@ class ModelTesterMixin:
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)

    @require_torch_accelerator_with_training
    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        inputs_dict_copy = copy.deepcopy(inputs_dict)
        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        torch.manual_seed(0)
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict_copy).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())

        for name, param in named_params.items():
            if "post_quant_conv" in name:
                continue
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))

    @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
    def test_gradient_checkpointing_is_applied(
        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
    ):
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing
        if self.model_class.__name__ in [
            "UNetSpatioTemporalConditionModel",
            "AutoencoderKLTemporalDecoder",
        ]:
            return

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        if attention_head_dim is not None:
            init_dict["attention_head_dim"] = attention_head_dim
        if num_attention_heads is not None:
            init_dict["num_attention_heads"] = num_attention_heads
        if block_out_channels is not None:
            init_dict["block_out_channels"] = block_out_channels

        model_class_copy = copy.copy(self.model_class)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        model = model_class_copy(**init_dict)
        model.enable_gradient_checkpointing()

        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")

        assert set(modules_with_gc_enabled.keys()) == expected_set
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

    def test_deprecated_kwargs(self):
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
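For reference, the new mixin-level test_effective_gradient_checkpointing asserts that enabling checkpointing changes neither the loss nor any parameter gradient beyond a small tolerance. A self-contained sketch of that property on a toy MLP using torch.utils.checkpoint directly; the network, shapes, and tolerances here are illustrative and not taken from the diff:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(4, 8)
labels = torch.randn(4, 8)

# baseline: ordinary forward/backward, keep the gradients around
loss = (net(x) - labels).mean()
loss.backward()
grads = [p.grad.clone() for p in net.parameters()]

# same weights, but activations are recomputed during the backward pass
net.zero_grad()
loss_ckpt = (checkpoint(net, x, use_reentrant=False) - labels).mean()
loss_ckpt.backward()

# checkpointing should only trade compute for memory, not change the math
assert torch.allclose(loss, loss_ckpt, atol=1e-5)
for g, p in zip(grads, net.parameters()):
    assert torch.allclose(g, p.grad, atol=5e-5)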
@@ -84,6 +84,13 @@ class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
        model = Transformer2DModel.from_config(init_dict)
        assert isinstance(model, DiTTransformer2DModel)

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DiTTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

    def test_correct_class_remapping_from_pretrained_config(self):
        config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
        model = Transformer2DModel.from_config(config)

@@ -92,6 +92,10 @@ class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
            expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"PixArtTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_correct_class_remapping_from_dict_config(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = Transformer2DModel.from_config(init_dict)

@@ -77,3 +77,7 @@ class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"AllegroTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@@ -74,6 +74,10 @@ class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"AuraFlowTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
    def test_set_attn_processor_for_determinism(self):
        pass

@@ -81,3 +81,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CogVideoXTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@@ -83,3 +83,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CogView3PlusTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@@ -111,3 +111,7 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
            torch.allclose(output_1, output_2, atol=1e-5),
            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"FluxTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@@ -86,3 +86,7 @@ class LatteTransformerTests(ModelTesterMixin, unittest.TestCase):
        super().test_output(
            expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"LatteTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@@ -84,6 +84,10 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
    def test_set_attn_processor_for_determinism(self):
        pass

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"SD3Transformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = SD3Transformer2DModel

@@ -139,3 +143,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"SD3Transformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@@ -43,7 +43,6 @@ from diffusers.utils.testing_utils import (
    require_peft_backend,
    require_torch_accelerator,
    require_torch_accelerator_with_fp16,
    require_torch_accelerator_with_training,
    require_torch_gpu,
    skip_mps,
    slow,

@@ -406,47 +405,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

    @require_torch_accelerator_with_training
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_model_with_attention_head_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -599,31 +557,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            check_sliceable_dim_attr(module)

    def test_gradient_checkpointing_is_applied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)

        model_class_copy = copy.copy(self.model_class)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        model = model_class_copy(**init_dict)
        model.enable_gradient_checkpointing()

        EXPECTED_SET = {
        expected_set = {
            "CrossAttnUpBlock2D",
            "CrossAttnDownBlock2D",
            "UNetMidBlock2DCrossAttn",

@@ -631,9 +565,11 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            "Transformer2DModel",
            "DownBlock2D",
        }

        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
        attention_head_dim = (8, 16)
        block_out_channels = (16, 32)
        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )

    def test_special_attn_proc(self):
        class AttnEasyProc(torch.nn.Module):
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

@@ -269,37 +268,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            assert_unfrozen(u.ctrl_to_base)

    def test_gradient_checkpointing_is_applied(self):
        model_class_copy = copy.copy(UNetControlNetXSModel)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = model_class_copy(**init_dict)

        model.enable_gradient_checkpointing()

        EXPECTED_SET = {
        expected_set = {
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }

        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @is_flaky
    def test_forward_no_control(self):
@@ -161,27 +161,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
        ), "xformers is not enabled"

    def test_gradient_checkpointing_is_applied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model_class_copy = copy.copy(self.model_class)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        model = model_class_copy(**init_dict)
        model.enable_gradient_checkpointing()

        EXPECTED_SET = {
        expected_set = {
            "CrossAttnUpBlockMotion",
            "CrossAttnDownBlockMotion",
            "UNetMidBlockCrossAttnMotion",

@@ -189,9 +169,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
            "Transformer2DModel",
            "DownBlockMotion",
        }

        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_feed_forward_chunking(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -25,7 +25,6 @@ from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    skip_mps,
    torch_all_close,
    torch_device,
)

@@ -160,47 +159,6 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_model_with_num_attention_heads_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -239,30 +197,7 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_gradient_checkpointing_is_applied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["num_attention_heads"] = (8, 16)

        model_class_copy = copy.copy(self.model_class)

        modules_with_gc_enabled = {}

        # now monkey patch the following function:
        # def _set_gradient_checkpointing(self, module, value=False):
        #     if hasattr(module, "gradient_checkpointing"):
        #         module.gradient_checkpointing = value

        def _set_gradient_checkpointing_new(self, module, value=False):
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = value
                modules_with_gc_enabled[module.__class__.__name__] = True

        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

        model = model_class_copy(**init_dict)
        model.enable_gradient_checkpointing()

        EXPECTED_SET = {
        expected_set = {
            "TransformerSpatioTemporalModel",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",

@@ -270,9 +205,10 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
            "CrossAttnUpBlockSpatioTemporal",
            "UNetMidBlockSpatioTemporal",
        }

        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
        num_attention_heads = (8, 16)
        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, num_attention_heads=num_attention_heads
        )

    def test_pickle(self):
        # enable deterministic behavior for gradient checkpointing