diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index d05af686de..bec62ce5cf 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             conditioning with `class_embed_type` equal to `None`.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -241,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
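For context on the two hunks above: `_supports_gradient_checkpointing` opts the model into `ModelMixin.enable_gradient_checkpointing()`, which applies the new `_set_gradient_checkpointing` hook to every submodule. A simplified paraphrase of that dispatch (not the verbatim `ModelMixin` implementation) looks like this:

```python
from functools import partial

def enable_gradient_checkpointing(self):
    # Models that never set the class flag fail fast.
    if not self._supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
    # `nn.Module.apply` visits every submodule, so each block that exposes a
    # `gradient_checkpointing` attribute gets switched on by the hook.
    self.apply(partial(self._set_gradient_checkpointing, value=True))
```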
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index b9d186ac1a..b4e0cea7c7 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 107a5a45bf..0fd8875a88 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 52bf5aba20..c584bdcf56 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -146,7 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     )

     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Decoder", "Encoder"}
+        expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     def test_from_pretrained_hub(self):
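The autoencoder test update above (and the temporal-decoder one below) falls out of the block changes: `AutoencoderKL`'s encoder and decoder build their mid block as a `UNetMidBlock2D`, which now carries a `gradient_checkpointing` flag, so the common test discovers one more checkpointed class. Roughly, the assertion amounts to this sketch, where `model` and `expected_set` stand in for what the real helper in `tests/models/test_modeling_common.py` sets up:

```python
model.enable_gradient_checkpointing()

# Collect the class names of every submodule whose flag was flipped on
# by `_set_gradient_checkpointing`.
modules_with_gc = {
    m.__class__.__name__
    for m in model.modules()
    if getattr(m, "gradient_checkpointing", False)
}
assert modules_with_gc == expected_set
```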
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index 4308cb6489..cf80ff5044 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -65,7 +65,7 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
         return init_dict, inputs_dict

     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Encoder", "TemporalDecoder"}
+        expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     @unittest.skip("Test unsupported.")
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index a7594f2ea1..91a462d587 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -803,7 +803,7 @@ class ModelTesterMixin:
             self.assertFalse(model.is_gradient_checkpointing)

     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing

@@ -850,6 +850,8 @@ class ModelTesterMixin:
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
+            if name in skip:
+                continue
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))

     @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
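The new `skip` argument covers parameters whose gradients cannot be compared one-to-one; the NCSNpp test at the end of this diff passes `skip={"time_proj.weight"}`, presumably because that model's Fourier time projection keeps its weight frozen (`requires_grad=False`), leaving no gradient to compare. The property being verified is sketched below; `model`, `sample`, `timestep`, and `skip` are illustrative placeholders, not the test's actual variable names:

```python
import torch

# Baseline: run without checkpointing and record per-parameter gradients.
model(sample, timestep).sample.mean().backward()
ref_grads = {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None}

# Re-run with checkpointing enabled: the forward output is identical, and
# the gradients from the recomputed backward must agree within tolerance.
model.zero_grad()
model.enable_gradient_checkpointing()
model(sample, timestep).sample.mean().backward()

for name, param in model.named_parameters():
    if name in skip or param.grad is None:
        continue
    torch.testing.assert_close(param.grad, ref_grads[name], atol=5e-5, rtol=0)
```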
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index 5f827f2742..ddf5f53511 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,23 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "AttnUpBlock2D",
+            "AttnDownBlock2D",
+            "UNetMidBlock2D",
+            "UpBlock2D",
+            "DownBlock2D",
+        }
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 8
+        block_out_channels = (16, 32)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+

 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -220,6 +237,17 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 32
+        block_out_channels = (32, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+

 class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -329,3 +357,17 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_forward_with_norm_groups(self):
         # not required for this model
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "UNetMidBlock2D",
+        }
+
+        block_out_channels = (32, 64, 64, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, block_out_channels=block_out_channels
+        )
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
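End to end, the feature added by this patch can be exercised as in the following minimal sketch; the config values are arbitrary small ones chosen for illustration, not taken from the diff:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(16, 32),
    norm_num_groups=16,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.enable_gradient_checkpointing()  # now supported by UNet2DModel
model.train()

sample = torch.randn(2, 3, 32, 32)
timesteps = torch.randint(0, 1000, (2,))
loss = model(sample, timesteps).sample.mean()
loss.backward()  # activations inside the checkpointed blocks are recomputed here
```

Because the blocks gate on `torch.is_grad_enabled()`, inference under `torch.no_grad()` takes the plain path and pays no recomputation cost.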