1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add test for group_offloading with training.

This commit is contained in:
sayakpaul
2025-05-07 14:37:35 +05:30
parent fb29132b98
commit 131ed8ed16

View File

@@ -1581,6 +1581,30 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
@parameterized.expand([False, True])
@require_torch_accelerator
def test_group_offloading_with_training(self, use_stream):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
model.enable_group_offload(
torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=use_stream
)
model.train()
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)