mirror of https://github.com/huggingface/diffusers.git

up
@@ -109,7 +109,7 @@ def check_if_lora_correctly_set(model) -> bool:
 
 
 def normalize_output(out):
-    out0 = out[0]
+    out0 = out[0] if isinstance(out, tuple) else out
     return torch.stack(out0) if isinstance(out0, list) else out0
 
 
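This hunk generalizes the shared helper so it accepts bare tensors as well as tuple-shaped model outputs. A standalone sketch of the resulting behavior (the helper body is from the diff; the demo calls are illustrative):

import torch

def normalize_output(out):
    # Unwrap tuple outputs; leave tensors and lists as-is
    out0 = out[0] if isinstance(out, tuple) else out
    # Collapse a list of tensors into one stacked tensor for comparison
    return torch.stack(out0) if isinstance(out0, list) else out0

t = torch.randn(2, 3)
assert torch.equal(normalize_output(t), t)          # bare tensor: passes through
assert torch.equal(normalize_output((t,)), t)       # tuple output: first element
assert normalize_output([t, t]).shape == (2, 2, 3)  # list output: stacked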
@@ -541,10 +541,8 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]
 
-        if isinstance(image, list):
-            image = torch.stack(image)
-        if isinstance(new_image, list):
-            new_image = torch.stack(new_image)
+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
 
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
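With both outputs normalized to plain tensors, the comparison reduces to a single scalar. A minimal self-contained illustration of the max-diff pattern (the tensors are made up):

import torch

image = torch.randn(1, 4, 32, 32)
new_image = image + 1e-7 * torch.randn_like(image)  # simulate tiny numerical drift

# Largest element-wise absolute deviation between the two forward passes
max_diff = (image - new_image).abs().max().item()
assert max_diff <= 5e-5, "Models give different forward passes"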
@@ -790,10 +788,8 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]
 
-        if isinstance(image, list):
-            image = torch.stack(image)
-        if isinstance(new_image, list):
-            new_image = torch.stack(new_image)
+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
 
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
@@ -857,10 +853,8 @@ class ModelTesterMixin:
         if isinstance(second, dict):
             second = second.to_tuple()[0]
 
-        if isinstance(first, list):
-            first = torch.stack(first)
-        if isinstance(second, list):
-            second = torch.stack(second)
+        first = normalize_output(first)
+        second = normalize_output(second)
 
         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
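This determinism check differs from the earlier hunks in that it finishes the comparison in NumPy; a sketch of the tail end with stand-in tensors:

import numpy as np
import torch

first = torch.randn(1, 4, 8, 8)
second = first.clone()  # a deterministic model reproduces its output exactly

out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()

# Same scalar distance as before, computed on NumPy arrays
max_diff = np.amax(np.abs(out_1 - out_2))
assert max_diff <= 1e-5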
@@ -1349,15 +1343,16 @@ class ModelTesterMixin:
             max_memory = {0: max_size, "cpu": model_size * 2}
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
 
             # Offload check
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
-            new_normalized_out = normalize_output(new_output)
+            new_normalized_output = normalize_output(new_output)
 
-            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_out, atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
 
     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
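model_size and max_size are computed earlier in this test; one plausible way to derive such a byte budget (the helper name and the 0.5 ratio are assumptions, not from the diff):

import torch

def parameter_nbytes(model: torch.nn.Module) -> int:
    # Total parameter storage in bytes, the basis for a max_memory budget
    return sum(p.numel() * p.element_size() for p in model.parameters())

model = torch.nn.Linear(512, 512)
model_size = parameter_nbytes(model)
# Cap device 0 below the full model size so device_map="auto" must spill to CPU
max_memory = {0: int(model_size * 0.5), "cpu": model_size * 2}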
@@ -1680,8 +1675,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         base_slice = model(**inputs_dict)[0]
-        if isinstance(base_slice, list):
-            base_slice = torch.stack(base_slice)
+        base_slice = normalize_output(base_slice)
         base_slice = base_slice.detach().flatten().cpu().numpy()
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
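check_linear_dtype is defined just below this hunk; its exact body is not shown in the diff, but the idea can be sketched standalone (this simplified version only verifies the storage side):

import torch

def check_linear_dtype(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
    # Layerwise casting stores weights in a compact dtype and upcasts to
    # compute_dtype on the fly; here we assert the storage side on every Linear
    for name, submodule in module.named_modules():
        if isinstance(submodule, torch.nn.Linear):
            actual = submodule.weight.dtype
            assert actual == storage_dtype, f"{name}: expected {storage_dtype}, got {actual}"

check_linear_dtype(torch.nn.Linear(4, 4).half(), torch.float16, torch.float32)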
@@ -1709,8 +1703,7 @@ class ModelTesterMixin:
 
             check_linear_dtype(model, storage_dtype, compute_dtype)
             output = model(**inputs_dict)[0]
-            if isinstance(output, list):
-                output = torch.stack(output)
+            output = normalize_output(output)
             output = output.float().flatten().detach().cpu().numpy()
 
             # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
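The .float() call in this hunk is load-bearing: NumPy has no bfloat16 (or float8) dtype, so converting a low-precision output without upcasting first raises a TypeError. A quick illustration:

import torch

output = torch.randn(4, dtype=torch.bfloat16)
# output.numpy() would raise TypeError: NumPy cannot represent bfloat16
as_numpy = output.float().flatten().detach().cpu().numpy()
print(as_numpy.dtype)  # float32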
@@ -1800,29 +1793,25 @@ class ModelTesterMixin:
 
         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
-        if isinstance(output_without_group_offloading, list):
-            output_without_group_offloading = torch.stack(output_without_group_offloading)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)
 
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
         output_with_group_offloading1 = run_forward(model)
-        if isinstance(output_with_group_offloading1, list):
-            output_with_group_offloading1 = torch.stack(output_with_group_offloading1)
+        output_with_group_offloading1 = normalize_output(output_with_group_offloading1)
 
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
         output_with_group_offloading2 = run_forward(model)
-        if isinstance(output_with_group_offloading2, list):
-            output_with_group_offloading2 = torch.stack(output_with_group_offloading2)
+        output_with_group_offloading2 = normalize_output(output_with_group_offloading2)
 
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="leaf_level")
         output_with_group_offloading3 = run_forward(model)
-        if isinstance(output_with_group_offloading3, list):
-            output_with_group_offloading3 = torch.stack(output_with_group_offloading3)
+        output_with_group_offloading3 = normalize_output(output_with_group_offloading3)
 
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
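Each variant in this hunk reseeds, rebuilds the model, enables one offloading flavor, and compares against the plain forward pass; reseeding before construction gives every rebuilt copy identical weights. A condensed sketch of that pattern (model_cls, init_dict, and run_forward stand in for the mixin's self.model_class, init_dict, and forward harness; enable_group_offload and its arguments are taken verbatim from the diff):

import torch

torch.manual_seed(0)
model = model_cls(**init_dict).eval().to(torch_device)
reference = normalize_output(run_forward(model))

torch.manual_seed(0)
model = model_cls(**init_dict).eval()
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
offloaded = normalize_output(run_forward(model))

# Offloading moves weights between devices but must not change the math
assert torch.allclose(reference, offloaded, atol=1e-5)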
@@ -1830,8 +1819,7 @@ class ModelTesterMixin:
             torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
         )
         output_with_group_offloading4 = run_forward(model)
-        if isinstance(output_with_group_offloading4, list):
-            output_with_group_offloading4 = torch.stack(output_with_group_offloading4)
+        output_with_group_offloading4 = normalize_output(output_with_group_offloading4)
 
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
@@ -1902,8 +1890,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         output_without_group_offloading = _run_forward(model, inputs_dict)
-        if isinstance(output_without_group_offloading, list):
-            output_without_group_offloading = torch.stack(output_without_group_offloading)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)
 
         torch.manual_seed(0)
         model = self.model_class(**init_dict)
@@ -1939,8 +1926,7 @@ class ModelTesterMixin:
                 raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
 
             output_with_group_offloading = _run_forward(model, inputs_dict)
-            if isinstance(output_with_group_offloading, list):
-                output_with_group_offloading = torch.stack(output_with_group_offloading)
+            output_with_group_offloading = normalize_output(output_with_group_offloading)
             self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
 
     def test_auto_model(self, expected_max_diff=5e-5):
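The missing_files guard in this final hunk ensures disk offloading actually wrote the expected files before the outputs are compared. A standalone sketch of that check (the helper name is mine; the raise line mirrors the diff):

import os

def find_missing_files(offload_dir: str, expected_files: list[str]) -> list[str]:
    # Anything the offloader was expected to write but is absent on disk
    present = set(os.listdir(offload_dir))
    return [name for name in expected_files if name not in present]

missing_files = find_missing_files(".", [])  # empty expectation: nothing missing
if missing_files:
    raise ValueError(f"Following files are missing: {', '.join(missing_files)}")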