Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Enable Gradient Checkpointing for UNet2DModel (New) (#7201)
* Port UNet2DModel gradient checkpointing code from #6718.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
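With this change, the standard ModelMixin switch applies to UNet2DModel as well. A minimal, illustrative training-step sketch (the small config, loss, and tensor shapes below are placeholders chosen for the example, not part of this PR):

import torch
from diffusers import UNet2DModel

# Small illustrative config; any valid UNet2DModel config behaves the same way.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.enable_gradient_checkpointing()  # now supported: flips gradient_checkpointing on the blocks
model.train()

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
noise = torch.randn_like(sample)

pred = model(sample, timestep).sample
loss = torch.nn.functional.mse_loss(pred, noise)
loss.backward()  # block activations are recomputed here instead of being cached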
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             conditioning with `class_embed_type` equal to `None`.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -241,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
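For context, `_set_gradient_checkpointing` is the per-module hook that `ModelMixin.enable_gradient_checkpointing()` applies over the module tree. Roughly how that wiring works (a simplified sketch, not the verbatim library code):

from functools import partial

def enable_gradient_checkpointing(self):
    # Simplified: the real ModelMixin method looks approximately like this.
    if not self._supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
    # nn.Module.apply visits every submodule, so each block that defines a
    # `gradient_checkpointing` attribute is switched on.
    self.apply(partial(self._set_gradient_checkpointing, value=True))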
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states
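The pattern added above (and repeated in the other blocks below) is plain `torch.utils.checkpoint` usage: the resnet call is wrapped in a closure, its intermediate activations are discarded during the forward pass, and they are recomputed during backward, trading extra compute for lower peak memory. A self-contained illustration with a toy module (the module and sizes are my own; only the mechanism matches):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.SiLU(), nn.Conv2d(64, 64, 3, padding=1))
x = torch.randn(2, 64, 32, 32, requires_grad=True)

# Regular forward: intermediate activations inside `block` stay in memory for backward.
y_regular = block(x)

# Checkpointed forward: only the input is kept; intermediates are recomputed in backward.
# The diff passes use_reentrant=False only on torch >= 1.11, where the kwarg exists.
y_ckpt = checkpoint(block, x, use_reentrant=False)

assert torch.allclose(y_regular, y_ckpt)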
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states
@@ -146,7 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         )

     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Decoder", "Encoder"}
+        expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     def test_from_pretrained_hub(self):
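`UNetMidBlock2D` appears in the autoencoder's expected set because the KL VAE's Encoder and Decoder build their mid block from that same class, so it now participates in checkpointing too. A quick way to see the relationship (default random-init config, nothing is downloaded):

from diffusers import AutoencoderKL

vae = AutoencoderKL()
print(type(vae.encoder.mid_block).__name__)  # UNetMidBlock2D
print(type(vae.decoder.mid_block).__name__)  # UNetMidBlock2D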
@@ -65,7 +65,7 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         return init_dict, inputs_dict

     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Encoder", "TemporalDecoder"}
+        expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     @unittest.skip("Test unsupported.")
@@ -803,7 +803,7 @@ class ModelTesterMixin:
         self.assertFalse(model.is_gradient_checkpointing)

     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
@@ -850,6 +850,8 @@ class ModelTesterMixin:
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
+            if name in skip:
+                continue
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))

     @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
@@ -105,6 +105,23 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "AttnUpBlock2D",
+            "AttnDownBlock2D",
+            "UNetMidBlock2D",
+            "UpBlock2D",
+            "DownBlock2D",
+        }
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 8
+        block_out_channels = (16, 32)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+

 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -220,6 +237,17 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 32
+        block_out_channels = (32, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+

 class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -329,3 +357,17 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_forward_with_norm_groups(self):
         # not required for this model
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "UNetMidBlock2D",
+        }
+
+        block_out_channels = (32, 64, 64, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, block_out_channels=block_out_channels
+        )
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
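A likely reason for skipping `time_proj.weight` here (my reading of the test setup, not stated in the PR): the NCSNpp tests use the Fourier time embedding, and `GaussianFourierProjection` registers its weight as non-trainable, so that parameter never has a gradient to compare:

from diffusers.models.embeddings import GaussianFourierProjection

proj = GaussianFourierProjection(embedding_size=16)
print(proj.weight.requires_grad)  # False -> param.grad stays None, hence the skip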