# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

import pytest
import torch

from ...testing_utils import torch_device


class ModelTesterMixin:
    """
    Base mixin class for model testing with common test methods.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
    - base_precision: Default tolerance for floating point comparisons (default: 1e-3)
    - pretrained_model_name_or_path: Checkpoint to load in tests that call `from_pretrained`
      on a real checkpoint (used by `test_keep_in_fp32_modules`)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
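
    Example (a minimal, hypothetical sketch; `MyTinyModel` and its init/input
    arguments are placeholders, not a real diffusers model):

        class MyTinyModelTests(ModelTesterMixin):
            model_class = MyTinyModel
            main_input_name = "sample"

            def get_init_dict(self):
                return {"in_channels": 4, "out_channels": 4}

            def get_dummy_inputs(self):
                return {"sample": torch.randn(1, 4, 8, 8, device=torch_device)}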
    """

    model_class = None
    base_precision = 1e-3

    def get_init_dict(self):
        raise NotImplementedError("get_init_dict must be implemented by subclasses.")

    def get_dummy_inputs(self):
        raise NotImplementedError(
            "get_dummy_inputs must be implemented by subclasses. It should return inputs_dict."
        )

    def check_device_map_is_respected(self, model, device_map):
        """Helper method to check if device map is correctly applied to model parameters."""
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError(f"device map is incomplete, it does not contain any device for `{param_name}`.")

            param_device = device_map[param_name]
            if param_device in ["cpu", "disk"]:
                assert param.device == torch.device(
                    "meta"
                ), f"Expected device 'meta' for {param_name}, got {param.device}"
            else:
                assert param.device == torch.device(
                    param_device
                ), f"Expected device {param_device} for {param_name}, got {param.device}"

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        """Test that model can be saved and loaded with save_pretrained/from_pretrained."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())

            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
        """Test save_pretrained/from_pretrained with variant parameter."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
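
            # With variant="fp16", weights are saved under variant-suffixed
            # filenames (e.g. diffusion_pytorch_model.fp16.safetensors), so
            # loading without the variant cannot find a default weight file.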

            # non-variant checkpoint cannot be loaded
            with pytest.raises(OSError) as exc_info:
                self.model_class.from_pretrained(tmpdirname)

            # make sure the error message states which file is missing
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)

            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_dtype(self):
        """Test save_pretrained/from_pretrained preserves dtype correctly."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
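                # When `_keep_in_fp32_modules` is None (no modules pinned to
                # fp32), loading without the low-memory (meta-device) path
                # should also honor the requested torch_dtype, so exercise
                # that branch as well.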
                if (
                    hasattr(self.model_class, "_keep_in_fp32_modules")
                    and self.model_class._keep_in_fp32_modules is None
                ):
                    new_model = self.model_class.from_pretrained(
                        tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
                    )
                    assert new_model.dtype == dtype

    def test_determinism(self, expected_max_diff=1e-5):
        """Test that model outputs are deterministic across multiple forward passes."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**self.get_dummy_inputs())
            if isinstance(first, dict):
                first = first.to_tuple()[0]

            second = model(**self.get_dummy_inputs())
            if isinstance(second, dict):
                second = second.to_tuple()[0]

        # Compute the max difference over positions where neither output is NaN
        first_flat = first.flatten()
        second_flat = second.flatten()

        mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
        first_filtered = first_flat[mask]
        second_filtered = second_flat[mask]

        max_diff = torch.abs(first_filtered - second_filtered).max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_output(self, expected_output_shape=None):
        """Test that model produces output with expected shape."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        inputs_dict = self.get_dummy_inputs()
        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        assert output is not None, "Model output is None"
        assert (
            output.shape == expected_output_shape
        ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"

    def test_model_from_pretrained(self):
        """Test that model loaded from pretrained matches original model."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        # test that the model can be reloaded from a saved checkpoint
        # and that all parameters keep the expected shapes
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check that all parameter shapes are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            assert (
                param_1.shape == param_2.shape
            ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"

        with torch.no_grad():
            output_1 = model(**self.get_dummy_inputs())

            if isinstance(output_1, dict):
                output_1 = output_1.to_tuple()[0]

            output_2 = new_model(**self.get_dummy_inputs())

            if isinstance(output_2, dict):
                output_2 = output_2.to_tuple()[0]

        assert (
            output_1.shape == output_2.shape
        ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}"

    def test_outputs_equivalence(self):
        """Test that dict and tuple outputs are equivalent."""

        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)
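
        # `recursive_check` walks the tuple output and the dict output in
        # parallel: positional entries of the tuple are compared against the
        # dict's values in order, recursing into nested containers and
        # comparing leaf tensors with `torch.allclose` (NaNs zeroed out first).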
        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (list, tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                assert torch.allclose(
                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                ), (
                    "Tuple and dict output are not equal. Difference:"
                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                )

        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**self.get_dummy_inputs())
            outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    def test_model_config_to_json_string(self):
        """Test model config can be serialized to JSON string."""
        model = self.model_class(**self.get_init_dict())

        json_string = model.config.to_json_string()
        assert isinstance(json_string, str), "Config to_json_string should return a string"
        assert len(json_string) > 0, "JSON string should not be empty"

    def test_keep_in_fp32_modules(self):
        r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32 when the model is
        loaded in fp16/bf16.
        """
        if not getattr(self.model_class, "_keep_in_fp32_modules", None):
            pytest.skip("Model does not have _keep_in_fp32_modules")

        fp32_modules = self.model_class._keep_in_fp32_modules

        for torch_dtype in [torch.bfloat16, torch.float16]:
            model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to(
                torch_device
            )
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
                    assert param.data.dtype == torch.float32, f"Expected fp32 for {name}, got {param.data.dtype}"
                else:
                    assert param.data.dtype == torch_dtype, f"Expected {torch_dtype} for {name}, got {param.data.dtype}"