mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
commit c137ae1cda
parent a74a8f7885
Author: sayakpaul
Date:   2025-11-28 22:07:17 +05:30


@@ -109,7 +109,7 @@ def check_if_lora_correctly_set(model) -> bool:
 def normalize_output(out):
-    out0 = out[0]
+    out0 = out[0] if isinstance(out, tuple) else out
     return torch.stack(out0) if isinstance(out0, list) else out0
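With the old `out0 = out[0]`, a bare tensor output was also indexed, which slices off its first batch row instead of passing the tensor through; guarding on `isinstance(out, tuple)` restricts the unwrapping to tuple-shaped model outputs. A standalone sketch of the fixed helper (the sample tensors are illustrative):

import torch

def normalize_output(out):
    # Unwrap tuple outputs such as (sample,); leave tensors and lists alone.
    out0 = out[0] if isinstance(out, tuple) else out
    # Stack list-of-tensor outputs into one tensor; pass tensors through.
    return torch.stack(out0) if isinstance(out0, list) else out0

t = torch.randn(2, 3)
assert torch.equal(normalize_output(t), t)          # bare tensor: unchanged
assert torch.equal(normalize_output((t,)), t)       # tuple: unwrapped
assert normalize_output([t, t]).shape == (2, 2, 3)  # list: stacked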
@@ -541,10 +541,8 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]
-        if isinstance(image, list):
-            image = torch.stack(image)
-        if isinstance(new_image, list):
-            new_image = torch.stack(new_image)
+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
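The replaced branches and the new helper agree on every case the old inline code handled; a quick self-contained check (tensors are illustrative):

import torch

def normalize_output(out):  # as fixed above
    out0 = out[0] if isinstance(out, tuple) else out
    return torch.stack(out0) if isinstance(out0, list) else out0

image = [torch.zeros(2, 3), torch.ones(2, 3)]
old = torch.stack(image) if isinstance(image, list) else image  # replaced inline pattern
assert torch.equal(old, normalize_output(image))  # helper agrees on list outputs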
@@ -790,10 +788,8 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]
-        if isinstance(image, list):
-            image = torch.stack(image)
-        if isinstance(new_image, list):
-            new_image = torch.stack(new_image)
+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
@@ -857,10 +853,8 @@ class ModelTesterMixin:
         if isinstance(second, dict):
             second = second.to_tuple()[0]
-        if isinstance(first, list):
-            first = torch.stack(first)
-        if isinstance(second, list):
-            second = torch.stack(second)
+        first = normalize_output(first)
+        second = normalize_output(second)
         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
@@ -1349,15 +1343,16 @@ class ModelTesterMixin:
             max_memory = {0: max_size, "cpu": model_size * 2}
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
-            # Offload check
+            # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
-            new_normalized_out = normalize_output(new_output)
-            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_out, atol=1e-5))
+            new_normalized_output = normalize_output(new_output)
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
 
     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
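For reference, `hf_device_map` maps submodule names to devices, so the `assertSetEqual` above only passes when at least one submodule really landed on the CPU. An illustrative shape of the mapping (the module names are made up):

# Hypothetical layout produced by device_map="auto" under a tight max_memory budget:
hf_device_map = {"conv_in": 0, "mid_block": 0, "conv_out": "cpu"}
# Both the accelerator (0) and "cpu" must appear, i.e. part of the model was offloaded.
assert set(hf_device_map.values()) == {0, "cpu"}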
@@ -1680,8 +1675,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         base_slice = model(**inputs_dict)[0]
-        if isinstance(base_slice, list):
-            base_slice = torch.stack(base_slice)
+        base_slice = normalize_output(base_slice)
         base_slice = base_slice.detach().flatten().cpu().numpy()
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
@@ -1709,8 +1703,7 @@ class ModelTesterMixin:
             check_linear_dtype(model, storage_dtype, compute_dtype)
             output = model(**inputs_dict)[0]
-            if isinstance(output, list):
-                output = torch.stack(output)
+            output = normalize_output(output)
             output = output.float().flatten().detach().cpu().numpy()
             # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
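The `storage_dtype`/`compute_dtype` pair checked here exercises diffusers' layerwise casting. A hedged usage sketch, assuming the public `enable_layerwise_casting` API (the checkpoint path is a placeholder, and torch.float8_e4m3fn needs a recent PyTorch):

import torch
from diffusers import AutoModel  # any ModelMixin subclass works the same way

model = AutoModel.from_pretrained("some/checkpoint", subfolder="transformer")  # hypothetical repo
# Store weights in fp8 and upcast them to bf16 on the fly during each forward pass.
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)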
@@ -1800,29 +1793,25 @@ class ModelTesterMixin:
         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
-        if isinstance(output_without_group_offloading, list):
-            output_without_group_offloading = torch.stack(output_without_group_offloading)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
         output_with_group_offloading1 = run_forward(model)
-        if isinstance(output_with_group_offloading1, list):
-            output_with_group_offloading1 = torch.stack(output_with_group_offloading1)
+        output_with_group_offloading1 = normalize_output(output_with_group_offloading1)
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
         output_with_group_offloading2 = run_forward(model)
-        if isinstance(output_with_group_offloading2, list):
-            output_with_group_offloading2 = torch.stack(output_with_group_offloading2)
+        output_with_group_offloading2 = normalize_output(output_with_group_offloading2)
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="leaf_level")
         output_with_group_offloading3 = run_forward(model)
-        if isinstance(output_with_group_offloading3, list):
-            output_with_group_offloading3 = torch.stack(output_with_group_offloading3)
+        output_with_group_offloading3 = normalize_output(output_with_group_offloading3)
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
@@ -1830,8 +1819,7 @@ class ModelTesterMixin:
             torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
         )
         output_with_group_offloading4 = run_forward(model)
-        if isinstance(output_with_group_offloading4, list):
-            output_with_group_offloading4 = torch.stack(output_with_group_offloading4)
+        output_with_group_offloading4 = normalize_output(output_with_group_offloading4)
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
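The four variants compared above are the group-offloading configurations the test expects to be numerically equivalent, collected here from the hunks (with `record_stream=True` standing in for the test's parametrized flag, and `model`/`torch_device` as in the surrounding test):

# block_level offloads groups of consecutive blocks; leaf_level offloads each
# leaf module, optionally overlapping transfers with a side stream.
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
model.enable_group_offload(torch_device, offload_type="leaf_level")
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True, record_stream=True)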
@@ -1902,8 +1890,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         output_without_group_offloading = _run_forward(model, inputs_dict)
-        if isinstance(output_without_group_offloading, list):
-            output_without_group_offloading = torch.stack(output_without_group_offloading)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
@@ -1939,8 +1926,7 @@ class ModelTesterMixin:
                 raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
             output_with_group_offloading = _run_forward(model, inputs_dict)
-            if isinstance(output_with_group_offloading, list):
-                output_with_group_offloading = torch.stack(output_with_group_offloading)
+            output_with_group_offloading = normalize_output(output_with_group_offloading)
             self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
 
     def test_auto_model(self, expected_max_diff=5e-5):
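This last hunk sits in the disk-backed variant of the group-offloading test, where offloaded groups are written to files and the missing-files check verifies they actually materialized. A hedged sketch of the enabling call, assuming an `offload_to_disk_path` parameter (the path is illustrative; `model` and `torch_device` as in the test):

model.enable_group_offload(
    torch_device,
    offload_type="leaf_level",
    offload_to_disk_path="/tmp/group_offload",  # illustrative; the test uses a tempdir
)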