mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add test for group_offloading with training.
This commit is contained in:
@@ -1581,6 +1581,30 @@ class ModelTesterMixin:
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
|
||||
@parameterized.expand([False, True])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_with_training(self, use_stream):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
if not getattr(model, "_supports_group_offloading", True):
|
||||
return
|
||||
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=use_stream
|
||||
)
|
||||
model.train()
|
||||
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
def test_auto_model(self, expected_max_diff=5e-5):
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
|
||||
Reference in New Issue
Block a user