mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
DN6
2025-12-15 16:02:38 +05:30
parent d08e0bb545
commit eae7543712
10 changed files with 140 additions and 383 deletions

View File

@@ -1,9 +1,10 @@
from .attention import AttentionTesterMixin, ContextParallelTesterMixin
from .attention import AttentionTesterMixin
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesTesterMixin,
GGUFTesterMixin,

View File

@@ -13,13 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
@@ -28,8 +24,6 @@ from diffusers.models.attention_processor import (
from ...testing_utils import (
assert_tensors_close,
is_attention,
is_context_parallel,
require_torch_multi_accelerator,
torch_device,
)
@@ -71,9 +65,7 @@ class AttentionTesterMixin:
# Get output before fusion
with torch.no_grad():
output_before_fusion = model(**inputs_dict)
if isinstance(output_before_fusion, dict):
output_before_fusion = output_before_fusion.to_tuple()[0]
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
# Fuse projections
model.fuse_qkv_projections()
@@ -90,9 +82,7 @@ class AttentionTesterMixin:
if has_fused_projections:
# Get output after fusion
with torch.no_grad():
output_after_fusion = model(**inputs_dict)
if isinstance(output_after_fusion, dict):
output_after_fusion = output_after_fusion.to_tuple()[0]
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
# Verify outputs match
assert_tensors_close(
@@ -115,9 +105,7 @@ class AttentionTesterMixin:
# Get output after unfusion
with torch.no_grad():
output_after_unfusion = model(**inputs_dict)
if isinstance(output_after_unfusion, dict):
output_after_unfusion = output_after_unfusion.to_tuple()[0]
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
# Verify outputs still match
assert_tensors_close(
@@ -195,80 +183,3 @@ class AttentionTesterMixin:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
# Setup distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device)
if isinstance(output, dict):
output = output.to_tuple()[0]
if rank == 0:
result_queue.put(("success", output.shape))
except Exception as e:
if rank == 0:
result_queue.put(("error", str(e)))
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
base_precision = 1e-3
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
cp_dict = {cp_type: world_size}
ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
mp.spawn(
_context_parallel_worker,
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
nprocs=world_size,
join=True,
)
status, result = result_queue.get(timeout=60)
assert status == "success", f"Context parallel inference failed: {result}"
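The hunks above repeatedly collapse the "call the model, then unwrap the returned ModelOutput" pattern into a single `return_dict=False` call. A minimal, self-contained sketch of why the two forms are interchangeable (not part of this commit; the small UNet2DModel configuration is an arbitrary illustration):

```python
import torch
from diffusers import UNet2DModel

# Tiny illustrative model; any diffusers model returning a ModelOutput behaves the same way.
model = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    block_out_channels=(4, 8),
    norm_num_groups=2,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.eval()
sample = torch.randn(1, 3, 8, 8)

with torch.no_grad():
    # Old pattern: forward returns a ModelOutput (a dict subclass); unwrap it manually.
    out = model(sample, timestep=1)
    if isinstance(out, dict):
        out = out.to_tuple()[0]
    # New pattern: request a plain tuple and take the first element directly.
    out_tuple = model(sample, timestep=1, return_dict=False)[0]

assert torch.allclose(out, out_tuple)
```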

View File

@@ -259,7 +259,7 @@ class ModelTesterMixin:
pass
"""
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
@@ -278,15 +278,8 @@ class ModelTesterMixin:
)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -308,14 +301,8 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -343,13 +330,8 @@ class ModelTesterMixin:
model.eval()
with torch.no_grad():
first = model(**self.get_dummy_inputs())
if isinstance(first, dict):
first = first.to_tuple()[0]
second = model(**self.get_dummy_inputs())
if isinstance(second, dict):
second = second.to_tuple()[0]
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
# Filter out NaN values before comparison
first_flat = first.flatten()
@@ -369,10 +351,7 @@ class ModelTesterMixin:
inputs_dict = self.get_dummy_inputs()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
@@ -501,13 +480,8 @@ class ModelTesterMixin:
assert param.data.dtype == dtype
with torch.no_grad():
output = model(**self.get_dummy_inputs())
if isinstance(output, dict):
output = output.to_tuple()[0]
output_loaded = model_loaded(**self.get_dummy_inputs())
if isinstance(output_loaded, dict):
output_loaded = output_loaded.to_tuple()[0]
output = model(**self.get_dummy_inputs(), return_dict=False)[0]
output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
@@ -519,7 +493,7 @@ class ModelTesterMixin:
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -539,10 +513,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
)
@require_accelerator
@@ -553,7 +527,7 @@ class ModelTesterMixin:
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -578,10 +552,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
)
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
@@ -593,7 +567,7 @@ class ModelTesterMixin:
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -628,10 +602,10 @@ class ModelTesterMixin:
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading"
base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
)
finally:
@@ -652,7 +626,7 @@ class ModelTesterMixin:
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
@@ -668,8 +642,8 @@ class ModelTesterMixin:
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism"
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
)
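`assert_tensors_close` comes from the local `testing_utils` module and is not shown in this diff. A hypothetical sketch of such a helper, built on `torch.testing.assert_close`, only to make the `atol`/`rtol` semantics used above concrete (the real helper may differ):

```python
import torch


def assert_tensors_close(actual, expected, atol=1e-5, rtol=0.0, msg=""):
    """Fail with a readable message when two tensors differ beyond atol/rtol."""
    try:
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    except AssertionError as err:
        raise AssertionError(f"{msg}\n{err}") from err
```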

View File

@@ -1,109 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import uuid
import pytest
import torch
from huggingface_hub.utils import ModelCard, delete_repo, is_jinja_available
from ...others.test_utils import TOKEN, USER, is_staging_test
@is_staging_test
class ModelPushToHubTesterMixin:
"""
Mixin class for testing push_to_hub functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
"""
identifier = uuid.uuid4()
repo_id = f"test-model-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def test_push_to_hub(self):
"""Test pushing model to hub and loading it back."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), (
"Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
)
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
def test_push_to_hub_in_organization(self):
"""Test pushing model to hub in organization namespace."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.org_repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), (
"Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
)
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
def test_push_to_hub_library_name(self):
"""Test that library_name in model card is set to 'diffusers'."""
if not is_jinja_available():
pytest.skip("Model card tests cannot be performed without Jinja installed.")
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert model_card.library_name == "diffusers", (
f"Expected library_name 'diffusers', got {model_card.library_name}"
)
# Reset repo
delete_repo(self.repo_id, token=TOKEN)

View File

@@ -17,49 +17,9 @@
import pytest
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor
from ...testing_utils import is_ip_adapter, torch_device
def create_ip_adapter_state_dict(model):
"""
Create a dummy IP Adapter state dict for testing.
Args:
model: The model to create IP adapter weights for
Returns:
dict: IP adapter state dict with to_k_ip and to_v_ip weights
"""
ip_state_dict = {}
key_id = 1
for name in model.attn_processors.keys():
# Skip self-attention processors
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
if cross_attention_dim is None:
continue
# Get hidden size based on model architecture
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
# Create IP adapter processor to get state dict structure
sd = IPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).state_dict()
ip_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
key_id += 2
return {"ip_adapter": ip_state_dict}
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.

View File

@@ -79,8 +79,6 @@ class LoraTesterMixin:
"""
def setup_method(self):
from diffusers.loaders.peft import PeftAdapterMixin
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")

View File

@@ -455,10 +455,7 @@ class LayerwiseCastingTesterMixin:
inputs_dict = self.get_inputs_dict()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)

View File

@@ -128,9 +128,9 @@ class QuantizationTesterMixin:
model_quantized = self._create_quantized_model(config_kwargs)
num_params_quantized = model_quantized.num_parameters()
assert num_params == num_params_quantized, (
f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
)
assert (
num_params == num_params_quantized
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
model = self._load_unquantized_model()
@@ -140,19 +140,17 @@ class QuantizationTesterMixin:
mem_quantized = model_quantized.get_memory_footprint()
ratio = mem / mem_quantized
assert ratio >= expected_memory_reduction, (
f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
)
assert (
ratio >= expected_memory_reduction
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_quantized(**inputs)
output = model_quantized(**inputs, return_dict=False)[0]
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -197,10 +195,8 @@ class QuantizationTesterMixin:
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
output = model(**inputs, return_dict=False)[0]
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None with LoRA"
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
@@ -214,9 +210,7 @@ class QuantizationTesterMixin:
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_loaded(**inputs)
if isinstance(output, tuple):
output = output[0]
output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
def _test_quantized_layers(self, config_kwargs):
@@ -243,12 +237,12 @@ class QuantizationTesterMixin:
self._verify_if_layer_quantized(name, module, config_kwargs)
num_quantized_layers += 1
assert num_quantized_layers > 0, (
f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
)
assert num_quantized_layers == expected_quantized_layers, (
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
assert (
num_quantized_layers > 0
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
assert (
num_quantized_layers == expected_quantized_layers
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
@@ -272,9 +266,9 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert not self._is_module_quantized(
module
), f"Module {name} should not be quantized but was found to be quantized"
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
@@ -296,9 +290,9 @@ class QuantizationTesterMixin:
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
assert mem_with_exclusion > mem_fully_quantized, (
f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
)
assert (
mem_with_exclusion > mem_fully_quantized
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
def _test_quantization_device_map(self, config_kwargs):
"""
@@ -316,12 +310,38 @@ class QuantizationTesterMixin:
# Verify inference works
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
if isinstance(output, tuple):
output = output[0]
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
def _test_dequantize(self, config_kwargs):
"""
Test that dequantize() converts quantized model back to standard linear layers.
Args:
config_kwargs: Quantization config parameters
"""
model = self._create_quantized_model(config_kwargs)
# Verify model has dequantize method
if not hasattr(model, "dequantize"):
pytest.skip("Model does not have dequantize method")
# Dequantize the model
model.dequantize()
# Verify no modules are quantized after dequantization
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
# Verify inference still works after dequantization
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@is_bitsandbytes
@require_accelerator
@@ -379,9 +399,9 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin):
def _verify_if_layer_quantized(self, name, module, config_kwargs):
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
assert module.weight.__class__ == expected_weight_class, (
f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
)
assert (
module.weight.__class__ == expected_weight_class
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_num_parameters(self, config_name):
@@ -449,13 +469,13 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
assert (
module.weight.dtype == torch.float32
), f"Module {name} should be FP32 but is {module.weight.dtype}"
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
assert (
module.weight.dtype == torch.uint8
), f"Module {name} should be uint8 but is {module.weight.dtype}"
with torch.no_grad():
inputs = self.get_dummy_inputs()
@@ -476,6 +496,10 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
def test_bnb_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(self.BNB_CONFIGS["4bit_nf4"])
@is_quanto
@require_quanto
@@ -563,6 +587,10 @@ class QuantoTesterMixin(QuantizationTesterMixin):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
def test_quanto_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(self.QUANTO_WEIGHT_TYPES["int8"])
@is_torchao
@require_accelerator
@@ -649,6 +677,10 @@ class TorchAoTesterMixin(QuantizationTesterMixin):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
def test_torchao_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(self.TORCHAO_QUANT_TYPES["int8wo"])
@is_gguf
@require_accelerate
@@ -716,24 +748,9 @@ class GGUFTesterMixin(QuantizationTesterMixin):
def test_gguf_quantization_lora_inference(self):
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
def test_gguf_dequantize_model(self):
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
model = self._create_quantized_model()
model.dequantize()
def _check_for_gguf_linear(model):
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
for name, module in model.named_children():
_check_for_gguf_linear(module)
def test_gguf_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize({"compute_dtype": torch.bfloat16})
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@@ -826,3 +843,7 @@ class ModelOptTesterMixin(QuantizationTesterMixin):
def test_modelopt_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])
def test_modelopt_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(self.MODELOPT_CONFIGS["fp8"])
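The new `_test_dequantize` helper leans on the backend-specific `_is_module_quantized` hook, which is not shown in this diff. A hedged sketch of what that check could look like for the bitsandbytes backend, mirroring the weight classes used in `_verify_if_layer_quantized` above (the mixin's actual implementation may differ):

```python
import bitsandbytes as bnb
import torch


def _is_module_quantized(module: torch.nn.Module) -> bool:
    """Treat a linear layer as quantized if its weight is a bitsandbytes parameter class."""
    if not isinstance(module, torch.nn.Linear):
        return False
    return isinstance(module.weight, (bnb.nn.Params4bit, bnb.nn.Int8Params))
```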

View File

@@ -50,10 +50,7 @@ class TrainingTesterMixin:
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
@@ -68,10 +65,7 @@ class TrainingTesterMixin:
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
@@ -137,9 +131,7 @@ class TrainingTesterMixin:
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict)
if isinstance(out, dict):
out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]
out = model(**inputs_dict, return_dict=False)[0]
# run the backwards pass on the model
model.zero_grad()
@@ -158,9 +150,7 @@ class TrainingTesterMixin:
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy)
if isinstance(out_2, dict):
out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
# run the backwards pass on the model
model_2.zero_grad()
@@ -198,10 +188,7 @@ class TrainingTesterMixin:
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
@@ -212,10 +199,7 @@ class TrainingTesterMixin:
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
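The training hunks pair a forward/backward pass with `EMAModel` from `diffusers.training_utils`. An illustrative, stand-alone sketch of that update pattern, with a toy module standing in for `self.model_class(**init_dict)` (not taken from this commit):

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)  # stand-in for self.model_class(**init_dict)
model.train()
ema_model = EMAModel(model.parameters())
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

output = model(torch.randn(2, 4))
loss = torch.nn.functional.mse_loss(output, torch.randn_like(output))
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Refresh the shadow (EMA) weights after the optimizer step ...
ema_model.step(model.parameters())
# ... and copy them back into a model for evaluation when needed.
ema_model.copy_to(model.parameters())
```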

View File

@@ -43,10 +43,15 @@ ATTRIBUTE_TO_TESTER = {
ALWAYS_INCLUDE_TESTERS = [
"ModelTesterMixin",
"MemoryTesterMixin",
"AttentionTesterMixin",
"TorchCompileTesterMixin",
]
# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
"AttentionMixin",
"AttentionModuleMixin",
}
OPTIONAL_TESTERS = [
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
@@ -62,6 +67,17 @@ class ModelAnalyzer(ast.NodeVisitor):
def __init__(self):
self.model_classes = []
self.current_class = None
self.imports = set()
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.imports.add(alias.name.split(".")[-1])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
self.imports.add(alias.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef):
base_names = []
@@ -164,7 +180,7 @@ class ModelAnalyzer(ast.NodeVisitor):
return "<complex>"
def analyze_model_file(filepath: str) -> list[dict]:
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
with open(filepath) as f:
source = f.read()
@@ -172,10 +188,10 @@ def analyze_model_file(filepath: str) -> list[dict]:
analyzer = ModelAnalyzer()
analyzer.visit(tree)
return analyzer.model_classes
return analyzer.model_classes, analyzer.imports
def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]:
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
testers = list(ALWAYS_INCLUDE_TESTERS)
for base in model_info["bases"]:
@@ -195,6 +211,10 @@ def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]:
if "ContextParallelTesterMixin" not in testers:
testers.append("ContextParallelTesterMixin")
# Include AttentionTesterMixin if the model imports attention-related classes
if imports & ATTENTION_INDICATORS:
testers.append("AttentionTesterMixin")
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
@@ -335,9 +355,9 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
return "\n".join(lines)
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str]) -> str:
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
testers = determine_testers(model_info, include_optional)
testers = determine_testers(model_info, include_optional, imports)
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
lines = [
@@ -446,7 +466,7 @@ def main():
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
sys.exit(1)
model_classes = analyze_model_file(args.model_filepath)
model_classes, imports = analyze_model_file(args.model_filepath)
if not model_classes:
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
@@ -468,7 +488,7 @@ def main():
if "all" in include_optional:
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
generated_code = generate_test_file(model_info, args.model_filepath, include_optional)
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
if args.dry_run:
print(generated_code)
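For reference, a stand-alone sketch of the import-based gating the generator now performs: collect the names imported by a model file with `ast`, then enable `AttentionTesterMixin` only when one of the attention mixins shows up. The source string and base tester list below are illustrative:

```python
import ast

ATTENTION_INDICATORS = {"AttentionMixin", "AttentionModuleMixin"}

# Illustrative model-file contents; the script reads these from model_filepath.
source = """
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.modeling_utils import ModelMixin
"""

imports = set()
for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.Import):
        imports.update(alias.name.split(".")[-1] for alias in node.names)
    elif isinstance(node, ast.ImportFrom):
        imports.update(alias.name for alias in node.names)

testers = ["ModelTesterMixin", "MemoryTesterMixin", "TorchCompileTesterMixin"]
if imports & ATTENTION_INDICATORS:
    testers.append("AttentionTesterMixin")

print(testers)  # AttentionTesterMixin is included because AttentionModuleMixin is imported
```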