Mirror of https://github.com/huggingface/diffusers.git
Fix gradient checkpointing test (#797)
* Fix gradient checkpointing test
* more tests
commit 22963ed826
parent fab17528da
committed by GitHub
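The API exercised by this test is the gradient checkpointing toggle that diffusers models inherit from ModelMixin. As a brief usage sketch (the pretrained checkpoint name below is only an example, not part of this commit):

from diffusers import UNet2DConditionModel

# any diffusers model inheriting from ModelMixin exposes the same toggle
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

unet.enable_gradient_checkpointing()    # trade extra compute for lower activation memory
assert unet.is_gradient_checkpointing

unet.disable_gradient_checkpointing()
assert not unet.is_gradient_checkpointing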
@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)

         assert not model.is_gradient_checkpointing and model.training

         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()

-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()

-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()

         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))


     # TODO(Patrick) - Re-add this test after having cleaned up LDM
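The property the fixed test relies on, that a model produces the same loss and the same parameter gradients whether or not activation checkpointing is enabled, can be reproduced in isolation. The sketch below is a minimal, self-contained version of the same comparison pattern; it uses a toy TinyBlock module and plain torch.utils.checkpoint instead of the diffusers UNet, so the module, shapes, and tolerances are illustrative assumptions rather than part of the test.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    # toy stand-in for a UNet block; not part of the diffusers test suite
    def __init__(self, use_checkpointing: bool = False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.layers = nn.Sequential(nn.Linear(16, 16), nn.SiLU(), nn.Linear(16, 16))

    def forward(self, x):
        if self.use_checkpointing:
            # recompute the activations of self.layers during backward instead of storing them
            return checkpoint(self.layers, x, use_reentrant=False)
        return self.layers(x)


torch.manual_seed(0)
model = TinyBlock(use_checkpointing=False)

# second model with identical weights, but running its forward through checkpointing
model_2 = TinyBlock(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())

x = torch.randn(4, 16)
labels = torch.randn(4, 16)

loss = (model(x) - labels).mean()
loss.backward()

loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# the loss and every parameter gradient should match within a small tolerance
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name

This mirrors the structure of the updated test: back-propagate the same loss through two identically initialized models, one with checkpointing enabled, then compare the losses and each named parameter's gradient with a small absolute tolerance.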