[Tests] clean up and refactor gradient checkpointing tests (#9494)

* check.

* fixes

* fixes

* updates

* fixes

* fixes
Sayak Paul
2024-10-31 18:24:19 +05:30
committed by GitHub
parent 8ce37ab055
commit 4adf6affbb
15 changed files with 180 additions and 273 deletions
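This PR replaces the near-identical per-model `test_gradient_checkpointing` methods with a shared `test_effective_gradient_checkpointing` on `ModelTesterMixin`, plus a configurable `test_gradient_checkpointing_is_applied`. The core check, visible in the hunks below, is: build the same seeded model twice, run one copy with gradient checkpointing enabled, and require the loss and every parameter gradient to match within a tolerance. A minimal, self-contained sketch of that check, using a hypothetical `TinyModel` rather than a diffusers model, might look like this:

import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    # Hypothetical toy model; `gradient_checkpointing` mimics the diffusers flag.
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(8, 32), nn.GELU())
        self.block2 = nn.Linear(32, 8)
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.gradient_checkpointing:
            # Recompute block1 activations during backward instead of storing them.
            # use_reentrant=False assumes a reasonably recent torch release.
            h = checkpoint(self.block1, x, use_reentrant=False)
        else:
            h = self.block1(x)
        return self.block2(h)


torch.manual_seed(0)
model = TinyModel()
model_ckpt = copy.deepcopy(model)          # identical weights
model_ckpt.gradient_checkpointing = True   # the only difference: checkpointing on

x = torch.randn(4, 8)
labels = torch.randn(4, 8)

loss = (model(x) - labels).mean()
loss.backward()

loss_ckpt = (model_ckpt(x) - labels).mean()
loss_ckpt.backward()

# Checkpointing trades compute for memory; it must not change the math.
assert torch.allclose(loss, loss_ckpt, atol=1e-5)
for (name, p), (_, p_ckpt) in zip(model.named_parameters(), model_ckpt.named_parameters()):
    assert torch.allclose(p.grad, p_ckpt.grad, atol=5e-5), name

The shared test below does the same thing with `self.model_class`, a fixed seed, and a deep copy of the inputs, and lets subclasses loosen `loss_tolerance` and `param_grad_tol` where recomputation noise is larger (DiT, for example, passes `loss_tolerance=1e-4`).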

View File

@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
@@ -170,52 +169,17 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Not tested.")
def test_forward_signature(self):
pass
@unittest.skip("Not tested.")
def test_training(self):
pass
@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Not tested.")
def test_forward_signature(self):
pass
@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass
@@ -364,9 +330,20 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Not tested.")
def test_outputs_equivalence(self):
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DecoderTiny", "EncoderTiny"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip(
"Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
)
def test_effective_gradient_checkpointing(self):
pass
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase)
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Not tested.")
def test_forward_signature(self):
pass
@unittest.skip("Not tested.")
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Not tested.")
def test_forward_signature(self):
pass
@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import json
import os
@@ -57,6 +58,7 @@ from diffusers.utils.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
run_test_in_subprocess,
torch_all_close,
torch_device,
)
@@ -785,6 +787,101 @@ class ModelTesterMixin:
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)
@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backward pass on the model; for simplicity, we backprop on the mean
# difference to random targets instead of computing a proper training loss
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy).sample
# run the backward pass on the checkpointed model with the same simple
# mean-difference loss so the two runs are directly comparable
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < loss_tolerance)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
def test_gradient_checkpointing_is_applied(
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
if self.model_class.__name__ in [
"UNetSpatioTemporalConditionModel",
"AutoencoderKLTemporalDecoder",
]:
return
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if attention_head_dim is not None:
init_dict["attention_head_dim"] = attention_head_dim
if num_attention_heads is not None:
init_dict["num_attention_heads"] = num_attention_heads
if block_out_channels is not None:
init_dict["block_out_channels"] = block_out_channels
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
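The shared `test_gradient_checkpointing_is_applied` above centralizes the monkey-patching trick the per-model tests used to copy-paste: replace the class's `_set_gradient_checkpointing` hook with one that also records which submodule classes had their `gradient_checkpointing` flag flipped, then compare that set against `expected_set`. A stripped-down sketch of the idiom, on hypothetical `ToyModel`/`ToyBlock` classes and assuming `enable_gradient_checkpointing()` walks submodules through that hook (as the `ModelMixin` subclasses in the hunks above do):

import copy

import torch.nn as nn


class ToyBlock(nn.Module):
    # Hypothetical sub-block that supports checkpointing via a boolean flag.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.gradient_checkpointing = False


class ToyModel(nn.Module):
    # Hypothetical top-level model exposing the same hook shape as in the diff.
    def __init__(self):
        super().__init__()
        self.encoder = ToyBlock()
        self.decoder = ToyBlock()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Flip the flag on every submodule through the hook.
        for module in self.modules():
            self._set_gradient_checkpointing(module, value=True)


# Shallow-copy the class (mirroring the mixin test) and swap in a hook that also
# records each module class it actually touches.
model_class_copy = copy.copy(ToyModel)
modules_with_gc_enabled = {}


def _set_gradient_checkpointing_new(self, module, value=False):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value
        modules_with_gc_enabled[module.__class__.__name__] = True


model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

model = model_class_copy()
model.enable_gradient_checkpointing()

assert set(modules_with_gc_enabled) == {"ToyBlock"}
assert all(modules_with_gc_enabled.values()), "All recorded modules should be enabled"

In the diff, subclasses now only supply `expected_set` (and occasionally `attention_head_dim`, `num_attention_heads`, or `block_out_channels`), so adding coverage for a new model is a short override rather than another copy of this machinery.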

View File

@@ -84,6 +84,13 @@ class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, DiTTransformer2DModel)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DiTTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
def test_correct_class_remapping_from_pretrained_config(self):
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
model = Transformer2DModel.from_config(config)

View File

@@ -92,6 +92,10 @@ class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"PixArtTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)

View File

@@ -77,3 +77,7 @@ class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"AllegroTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -74,6 +74,10 @@ class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"AuraFlowTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
def test_set_attn_processor_for_determinism(self):
pass

View File

@@ -81,3 +81,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -83,3 +83,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView3PlusTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -111,3 +111,7 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -86,3 +86,7 @@ class LatteTransformerTests(ModelTesterMixin, unittest.TestCase):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LatteTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -84,6 +84,10 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
def test_set_attn_processor_for_determinism(self):
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
@@ -139,3 +143,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -43,7 +43,6 @@ from diffusers.utils.testing_utils import (
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
@@ -406,47 +405,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
== "XFormersAttnProcessor"
), "xformers is not enabled"
@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -599,31 +557,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
@@ -631,9 +565,11 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"Transformer2DModel",
"DownBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
attention_head_dim = (8, 16)
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
@@ -269,37 +268,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
model_class_copy = copy.copy(UNetControlNetXSModel)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@is_flaky
def test_forward_no_control(self):

View File

@@ -161,27 +161,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
), "xformers is not enabled"
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
expected_set = {
"CrossAttnUpBlockMotion",
"CrossAttnDownBlockMotion",
"UNetMidBlockCrossAttnMotion",
@@ -189,9 +169,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
"Transformer2DModel",
"DownBlockMotion",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -25,7 +25,6 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_all_close,
torch_device,
)
@@ -160,47 +159,6 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
== "XFormersAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_num_attention_heads_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -239,30 +197,7 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
@@ -270,9 +205,10 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
num_attention_heads = (8, 16)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, num_attention_heads=num_attention_heads
)
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing