update
@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
                 repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
             )
 
-            assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
-                "Model parameters don't match!"
-            )
+            assert all(
+                torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
+            ), "Model parameters don't match!"
 
             # Remove a shard file
             cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
 
             # Verify error mentions the missing shard
             error_msg = str(context.exception)
-            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
-                f"Expected error about missing shard, got: {error_msg}"
-            )
+            assert (
+                cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
+            ), f"Expected error about missing shard, got: {error_msg}"
 
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
             )
 
             download_requests = [r.method for r in m.request_history]
-            assert download_requests.count("HEAD") == 3, (
-                "3 HEAD requests one for config, one for model, and one for shard index file."
-            )
+            assert (
+                download_requests.count("HEAD") == 3
+            ), "3 HEAD requests one for config, one for model, and one for shard index file."
             assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
 
         with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
             )
 
             cache_requests = [r.method for r in m.request_history]
-            assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
-                "We should call only `model_info` to check for commit hash and knowing if shard index is present."
-            )
+            assert (
+                "HEAD" == cache_requests[0] and len(cache_requests) == 2
+            ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
 
     def test_weight_overwrite(self):
         with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
tests/models/testing_utils/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .common import ModelTesterMixin
from .single_file import SingleFileTesterMixin
tests/models/testing_utils/attention.py (new file, empty)
tests/models/testing_utils/common.py (new file, 304 lines)
@@ -0,0 +1,304 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from typing import Dict, List, Tuple

import pytest
import torch

from ...testing_utils import torch_device

class ModelTesterMixin:
    """
    Base mixin class for model testing with common test methods.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
    - base_precision: Default tolerance for floating point comparisons (default: 1e-3)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    model_class = None
    base_precision = 1e-3

    def get_init_dict(self):
        raise NotImplementedError("get_init_dict must be implemented by subclasses.")

    def get_dummy_inputs(self):
        raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.")

    def check_device_map_is_respected(self, model, device_map):
        """Helper method to check if device map is correctly applied to model parameters."""
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError(f"device map is incomplete, it does not contain any device for {param_name}.")

            param_device = device_map[param_name]
            if param_device in ["cpu", "disk"]:
                assert param.device == torch.device(
                    "meta"
                ), f"Expected device 'meta' for {param_name}, got {param.device}"
            else:
                assert param.device == torch.device(
                    param_device
                ), f"Expected device {param_device} for {param_name}, got {param.device}"

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        """Test that model can be saved and loaded with save_pretrained/from_pretrained."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())

            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
        """Test save_pretrained/from_pretrained with variant parameter."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")

            # non-variant cannot be loaded
            with pytest.raises(OSError) as exc_info:
                self.model_class.from_pretrained(tmpdirname)

            # make sure that error message states what keys are missing
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)

            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_dtype(self):
        """Test save_pretrained/from_pretrained preserves dtype correctly."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                if (
                    hasattr(self.model_class, "_keep_in_fp32_modules")
                    and self.model_class._keep_in_fp32_modules is None
                ):
                    new_model = self.model_class.from_pretrained(
                        tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
                    )
                    assert new_model.dtype == dtype

    def test_determinism(self, expected_max_diff=1e-5):
        """Test that model outputs are deterministic across multiple forward passes."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**self.get_dummy_inputs())
            if isinstance(first, dict):
                first = first.to_tuple()[0]

            second = model(**self.get_dummy_inputs())
            if isinstance(second, dict):
                second = second.to_tuple()[0]

        # Remove NaN values and compute max difference
        first_flat = first.flatten()
        second_flat = second.flatten()

        # Filter out NaN values
        mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
        first_filtered = first_flat[mask]
        second_filtered = second_flat[mask]

        max_diff = torch.abs(first_filtered - second_filtered).max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_output(self, expected_output_shape=None):
        """Test that model produces output with expected shape."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        inputs_dict = self.get_dummy_inputs()
        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        assert output is not None, "Model output is None"
        assert (
            output.shape == expected_output_shape
        ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"

    def test_model_from_pretrained(self):
        """Test that model loaded from pretrained matches original model."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            assert (
                param_1.shape == param_2.shape
            ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"

        with torch.no_grad():
            output_1 = model(**self.get_dummy_inputs())

            if isinstance(output_1, dict):
                output_1 = output_1.to_tuple()[0]

            output_2 = new_model(**self.get_dummy_inputs())

            if isinstance(output_2, dict):
                output_2 = output_2.to_tuple()[0]

        assert (
            output_1.shape == output_2.shape
        ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}"

    def test_outputs_equivalence(self):
        """Test that dict and tuple outputs are equivalent."""

        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                assert torch.allclose(
                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                ), (
                    "Tuple and dict output are not equal. Difference:"
                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                )

        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**self.get_dummy_inputs())
            outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    def test_model_config_to_json_string(self):
        """Test model config can be serialized to JSON string."""
        model = self.model_class(**self.get_init_dict())

        json_string = model.config.to_json_string()
        assert isinstance(json_string, str), "Config to_json_string should return a string"
        assert len(json_string) > 0, "JSON string should not be empty"

    def test_keep_in_fp32_modules(self):
        r"""
        A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16.
        Also ensures that inference works.
        """
        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
            pytest.skip("Model does not have _keep_in_fp32_modules")

        fp32_modules = self.model_class._keep_in_fp32_modules

        for torch_dtype in [torch.bfloat16, torch.float16]:
            model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to(
                torch_device
            )
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
                    assert param.dtype == torch.float32
                else:
                    assert param.dtype == torch_dtype
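For context, a minimal sketch of a test class built on `ModelTesterMixin`. The tiny `UNet2DModel` config, the inline `torch_device` fallback, and all shape choices are illustrative assumptions, not part of this commit:

```python
import torch

from diffusers import UNet2DModel

from ..testing_utils.common import ModelTesterMixin

# Stand-in for `from ...testing_utils import torch_device`.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


class TestTinyUNet2D(ModelTesterMixin):
    model_class = UNet2DModel
    main_input_name = "sample"

    def get_init_dict(self):
        # Deliberately tiny config so the inherited tests run fast.
        return {
            "sample_size": 8,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (8, 8),
            "norm_num_groups": 4,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 8, 8, device=torch_device),
            "timestep": torch.tensor([10], device=torch_device),
        }
```

Every test method inherited from the mixin (save/load round-trips, determinism, output shape, dict/tuple equivalence) then runs against this one tiny model.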
tests/models/testing_utils/compile.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import tempfile

import pytest
import torch

from ...testing_utils import (
    backend_empty_cache,
    is_torch_compile,
    require_accelerator,
    require_torch_version_greater,
    torch_device,
)

@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
    """
    Mixin class for testing torch.compile functionality on models.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    different_shapes_for_compilation = None

    def setup_method(self):
        """Setup before each test method."""
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Cleanup after each test method."""
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_torch_compile_recompilation_and_graph_break(self):
        """Test that model compiles without graph breaks and doesn't recompile unnecessarily."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model = torch.compile(model, fullgraph=True)

        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
            torch.no_grad(),
        ):
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    def test_torch_compile_repeated_blocks(self):
        """Test compilation of repeated blocks if model supports it."""
        if self.model_class._repeated_blocks is None:
            pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile_repeated_blocks(fullgraph=True)

        recompile_limit = 1
        if self.model_class.__name__ == "UNet2DConditionModel":
            recompile_limit = 2

        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(recompile_limit=recompile_limit),
            torch.no_grad(),
        ):
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    def test_compile_with_group_offloading(self):
        """Test that compilation works with group offloading enabled."""
        if not self.model_class._supports_group_offloading:
            pytest.skip("Model does not support group offloading.")

        torch._dynamo.config.cache_size_limit = 10000

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.eval()

        group_offload_kwargs = {
            "onload_device": torch_device,
            "offload_device": "cpu",
            "offload_type": "block_level",
            "num_blocks_per_group": 1,
            "use_stream": True,
            "non_blocking": True,
        }
        model.enable_group_offload(**group_offload_kwargs)
        model.compile()

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    def test_compile_on_different_shapes(self):
        """Test dynamic compilation on different input shapes."""
        if self.different_shapes_for_compilation is None:
            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
        torch.fx.experimental._config.use_duck_shape = False

        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model = torch.compile(model, fullgraph=True, dynamic=True)

        for height, width in self.different_shapes_for_compilation:
            with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
                inputs_dict = self.get_dummy_inputs(height=height, width=width)
                _ = model(**inputs_dict)

    def test_compile_works_with_aot(self):
        """Test that model works with ahead-of-time compilation and packaging."""
        from torch._inductor.package import load_package

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict).to(torch_device)
        exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)

        with tempfile.TemporaryDirectory() as tmpdir:
            package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
            _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
            assert os.path.exists(package_path), f"Package file not created at {package_path}"
            loaded_binary = load_package(package_path, run_single_threaded=True)

        model.forward = loaded_binary

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
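The core pattern in `test_torch_compile_recompilation_and_graph_break` is worth isolating: compile once with `fullgraph=True`, then run identical inputs twice under `error_on_recompile` so any graph break or guard churn fails loudly instead of silently recompiling. A self-contained sketch, with a hypothetical toy module standing in for a diffusers model:

```python
import torch


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.proj(x))


model = torch.compile(TinyBlock().eval(), fullgraph=True)
x = torch.randn(2, 16)

# error_on_recompile turns a silent second compilation into a hard failure,
# which is how the mixin catches accidental graph breaks and guard churn.
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    _ = model(x)  # first call triggers the one expected compilation
    _ = model(x)  # second call must hit the compiled graph, or raise
```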
tests/models/testing_utils/hub.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import uuid

import pytest
import torch
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available

from ...others.test_utils import TOKEN, USER, is_staging_test

@is_staging_test
class ModelPushToHubTesterMixin:
    """
    Mixin class for testing push_to_hub functionality on models.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    """

    identifier = uuid.uuid4()
    repo_id = f"test-model-{identifier}"
    org_repo_id = f"valid_org/{repo_id}-org"

    def test_push_to_hub(self):
        """Test pushing model to hub and loading it back."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.push_to_hub(self.repo_id, token=TOKEN)

        new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)

        new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            assert torch.equal(
                p1, p2
            ), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"

        # Reset repo
        delete_repo(self.repo_id, token=TOKEN)

    def test_push_to_hub_in_organization(self):
        """Test pushing model to hub in organization namespace."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.push_to_hub(self.org_repo_id, token=TOKEN)

        new_model = self.model_class.from_pretrained(self.org_repo_id)
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"

        # Reset repo
        delete_repo(token=TOKEN, repo_id=self.org_repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)

        new_model = self.model_class.from_pretrained(self.org_repo_id)
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            assert torch.equal(
                p1, p2
            ), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"

        # Reset repo
        delete_repo(self.org_repo_id, token=TOKEN)

    def test_push_to_hub_library_name(self):
        """Test that library_name in model card is set to 'diffusers'."""
        if not is_jinja_available():
            pytest.skip("Model card tests cannot be performed without Jinja installed.")

        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.push_to_hub(self.repo_id, token=TOKEN)

        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
        assert (
            model_card.library_name == "diffusers"
        ), f"Expected library_name 'diffusers', got {model_card.library_name}"

        # Reset repo
        delete_repo(self.repo_id, token=TOKEN)
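Outside the staging fixtures, the round-trip these tests exercise looks roughly like this; the `user` placeholder, repo name, and tiny `UNet2DModel` config are illustrative, and authentication is assumed to come from a logged-in `huggingface-cli` session or `HF_TOKEN`:

```python
import uuid

import torch

from diffusers import UNet2DModel
from huggingface_hub import delete_repo

user = "your-username"  # hypothetical; the tests use the staging USER fixture
repo_id = f"test-model-{uuid.uuid4()}"

model = UNet2DModel(
    sample_size=8,
    layers_per_block=1,
    block_out_channels=(8, 8),
    norm_num_groups=4,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.push_to_hub(repo_id)

# Loading back by full id must reproduce the exact parameters.
reloaded = UNet2DModel.from_pretrained(f"{user}/{repo_id}")
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), reloaded.parameters()))

delete_repo(repo_id)  # clean up, mirroring the tests' "Reset repo" step
```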
tests/models/testing_utils/ip_adapter.py (new file, empty)
tests/models/testing_utils/lora.py (new file, empty)
tests/models/testing_utils/offloading.py (new file, empty)
tests/models/testing_utils/single_file.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile

import torch
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

from ...testing_utils import (
    backend_empty_cache,
    is_single_file,  # needed by the @is_single_file marker below; missing from the original import list
    nightly,
    require_torch_accelerator,
    torch_device,
)

def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
    """Download a single file checkpoint from the Hub to a temporary directory."""
    path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
    return path


def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
    """Download diffusers config files (excluding weights) from a repository."""
    path = snapshot_download(
        pretrained_model_name_or_path,
        ignore_patterns=[
            "**/*.ckpt",
            "*.ckpt",
            "**/*.bin",
            "*.bin",
            "**/*.pt",
            "*.pt",
            "**/*.safetensors",
            "*.safetensors",
        ],
        allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
        local_dir=tmpdir,
    )
    return path


@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
    """
    Mixin class for testing single file loading for models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - ckpt_path: Path or Hub path to the single file checkpoint
    - subfolder: (Optional) Subfolder within the repo
    - torch_dtype: (Optional) torch dtype to use for testing
    """

    pretrained_model_name_or_path = None
    ckpt_path = None

    def setup_method(self):
        """Setup before each test method."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Cleanup after each test method."""
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_model_config(self):
        """Test that config matches between pretrained and single file loading."""
        pretrained_kwargs = {}
        single_file_kwargs = {}

        pretrained_kwargs["device"] = torch_device
        single_file_kwargs["device"] = torch_device

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading: "
                f"pretrained={model.config[param_name]}, single_file={param_value}"
            )

    def test_single_file_model_parameters(self):
        """Test that parameters match between pretrained and single file loading."""
        pretrained_kwargs = {}
        single_file_kwargs = {}

        pretrained_kwargs["device"] = torch_device
        single_file_kwargs["device"] = torch_device

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading. "
            f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
            f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_single_file_loading_local_files_only(self):
        """Test single file loading with local_files_only=True."""
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        with tempfile.TemporaryDirectory() as tmpdir:
            pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)

            model_single_file = self.model_class.from_single_file(
                local_ckpt_path, local_files_only=True, **single_file_kwargs
            )

            assert model_single_file is not None, "Failed to load model with local_files_only=True"

    def test_single_file_loading_with_diffusers_config(self):
        """Test single file loading with diffusers config."""
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        # Load with config parameter
        model_single_file = self.model_class.from_single_file(
            self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
        )

        # Load pretrained for comparison
        pretrained_kwargs = {}
        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder
        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)

        # Compare configs
        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"

    def test_single_file_loading_with_diffusers_config_local_files_only(self):
        """Test single file loading with diffusers config and local_files_only=True."""
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        with tempfile.TemporaryDirectory() as tmpdir:
            pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)

            model_single_file = self.model_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
            )

            assert model_single_file is not None, "Failed to load model with config and local_files_only=True"

    def test_single_file_loading_dtype(self):
        """Test single file loading with different dtypes."""
        # torch.bfloat16 included so the mps guard below is reachable
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue

            model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)

            assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"

            # Cleanup
            del model_single_file
            gc.collect()
            backend_empty_cache(torch_device)

    def test_checkpoint_variant_loading(self):
        """Test loading checkpoints with alternate keys/variants if provided."""
        if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
            return

        for ckpt_path in self.alternate_ckpt_paths:
            backend_empty_cache(torch_device)

            single_file_kwargs = {}
            if hasattr(self, "torch_dtype") and self.torch_dtype:
                single_file_kwargs["torch_dtype"] = self.torch_dtype

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            assert model is not None, f"Failed to load checkpoint from {ckpt_path}"

            del model
            gc.collect()
            backend_empty_cache(torch_device)
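The `local_files_only` tests above hinge on one flow: split the Hub file URL into `(repo_id, weights_name)`, pre-download the checkpoint, then load strictly from disk. A condensed sketch under the same assumptions (this uses the full FLUX.1-dev weights from the tests, so it is a heavyweight download):

```python
import tempfile

from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from huggingface_hub import hf_hub_download

ckpt_url = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

with tempfile.TemporaryDirectory() as tmpdir:
    # -> ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors")
    repo_id, weights_name = _extract_repo_id_and_weights_name(ckpt_url)
    local_path = hf_hub_download(repo_id, filename=weights_name, local_dir=tmpdir)

    # With the checkpoint already on disk, loading must not touch the network.
    model = FluxTransformer2DModel.from_single_file(local_path, local_files_only=True)
```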
tests/models/testing_utils/training.py (new file, empty)
tests/models/transformers/test_models_transformer_flux_.py (new file, 154 lines)
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import FluxTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils.common import ModelTesterMixin
from ..testing_utils.compile import TorchCompileTesterMixin
from ..testing_utils.single_file import SingleFileTesterMixin


enable_full_determinism()


class FluxTransformerTesterConfig:
    model_class = FluxTransformer2DModel

    def get_init_dict(self):
        """Return Flux model initialization arguments."""
        return {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "pooled_projection_dim": 32,
            "axes_dims_rope": [4, 4, 8],
        }

    def get_dummy_inputs(self):
        batch_size = 1
        height = width = 4
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 24
        # must match joint_attention_dim / pooled_projection_dim in get_init_dict
        embedding_dim = 32

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }

    @property
    def input_shape(self):
        return (16, 4)

    @property
    def output_shape(self):
        return (16, 4)


class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
    def test_deprecated_inputs_img_txt_ids_3d(self):
        """Test that deprecated 3D img_ids and txt_ids still work."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict).to_tuple()[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict).to_tuple()[0]

        assert output_1.shape == output_2.shape
        assert torch.allclose(output_1, output_2, atol=1e-5), (
            "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
            "does not match output with the same ids as 2d inputs"
        )


class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    # named `alternate_ckpt_paths` so `test_checkpoint_variant_loading` picks it up
    alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
    subfolder = "transformer"


class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height=4, width=4):
        """Override to support dynamic height/width for compilation tests."""
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 24
        # must match joint_attention_dim / pooled_projection_dim in get_init_dict
        embedding_dim = 32

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }


class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height=4, width=4):
        """Override to support dynamic height/width for LoRA hotswap tests."""
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 48
        embedding_dim = 32

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
            "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
            "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
            "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }