mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
DN6
2025-12-11 11:04:47 +05:30
parent 0f1a4e0c14
commit fe451c367b
16 changed files with 679 additions and 706 deletions

View File

@@ -46,6 +46,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
def pytest_addoption(parser):

View File

@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
)
assert all(
torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
), "Model parameters don't match!"
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
"Model parameters don't match!"
)
# Remove a shard file
cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
# Verify error mentions the missing shard
error_msg = str(context.exception)
assert (
cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
), f"Expected error about missing shard, got: {error_msg}"
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
f"Expected error about missing shard, got: {error_msg}"
)
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
)
download_requests = [r.method for r in m.request_history]
assert (
download_requests.count("HEAD") == 3
), "3 HEAD requests one for config, one for model, and one for shard index file."
assert download_requests.count("HEAD") == 3, (
"3 HEAD requests one for config, one for model, and one for shard index file."
)
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
)
cache_requests = [r.method for r in m.request_history]
assert (
"HEAD" == cache_requests[0] and len(cache_requests) == 2
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
)
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:

View File

@@ -1,4 +1,4 @@
from .attention import AttentionTesterMixin
from .attention import AttentionTesterMixin, ContextParallelTesterMixin
from .common import ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
@@ -17,6 +17,7 @@ from .training import TrainingTesterMixin
__all__ = [
"ContextParallelTesterMixin",
"AttentionTesterMixin",
"BitsAndBytesTesterMixin",
"CPUOffloadTesterMixin",

View File

@@ -13,15 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import is_attention, torch_device
from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
@is_attention
@@ -85,9 +89,9 @@ class AttentionTesterMixin:
output_after_fusion = output_after_fusion.to_tuple()[0]
# Verify outputs match
assert torch.allclose(
output_before_fusion, output_after_fusion, atol=self.base_precision
), "Output should not change after fusing projections"
assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
"Output should not change after fusing projections"
)
# Unfuse projections
model.unfuse_qkv_projections()
@@ -106,9 +110,9 @@ class AttentionTesterMixin:
output_after_unfusion = output_after_unfusion.to_tuple()[0]
# Verify outputs still match
assert torch.allclose(
output_before_fusion, output_after_unfusion, atol=self.base_precision
), "Output should match original after unfusing projections"
assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
"Output should match original after unfusing projections"
)
def test_get_set_processor(self):
init_dict = self.get_init_dict()
@@ -177,3 +181,83 @@ class AttentionTesterMixin:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
# Setup distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device)
if isinstance(output, dict):
output = output.to_tuple()[0]
if rank == 0:
result_queue.put(("success", output.shape))
except Exception as e:
if rank == 0:
result_queue.put(("error", str(e)))
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
base_precision = 1e-3
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Context parallel requires at least 2 CUDA devices.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
cp_dict = {cp_type: world_size}
ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
mp.spawn(
_context_parallel_worker,
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
nprocs=world_size,
join=True,
)
status, result = result_queue.get(timeout=60)
assert status == "success", f"Context parallel inference failed: {result}"

View File

@@ -16,16 +16,45 @@
import json
import os
import tempfile
from collections import defaultdict
import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import torch_device
from ...testing_utils import CaptureLogger, torch_device
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional*, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
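# Illustrative use of the helper above: it yields a module's parameters plus persistent buffers
# (non-persistent buffers are filtered out), e.g. to total their storage in bytes:
#   n_bytes = sum(t.numel() * t.element_size()
#                 for _, t in named_persistent_module_tensors(module, recurse=True))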
def compute_module_persistent_sizes(
@@ -96,9 +125,9 @@ def check_device_map_is_respected(model, device_map):
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(
param_device
), f"Expected device {param_device} for {param_name}, got {param.device}"
assert param.device == torch.device(param_device), (
f"Expected device {param_device} for {param_name}, got {param.device}"
)
class ModelTesterMixin:
@@ -123,9 +152,7 @@ class ModelTesterMixin:
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
def get_dummy_inputs(self):
raise NotImplementedError(
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
)
raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.")
def test_from_save_pretrained(self, expected_max_diff=5e-5):
torch.manual_seed(0)
@@ -142,9 +169,9 @@ class ModelTesterMixin:
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert (
param_1.shape == param_2.shape
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
assert param_1.shape == param_2.shape, (
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
@@ -158,9 +185,9 @@ class ModelTesterMixin:
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
model = self.model_class(**self.get_init_dict())
@@ -191,9 +218,9 @@ class ModelTesterMixin:
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
def test_from_save_pretrained_dtype(self):
model = self.model_class(**self.get_init_dict())
@@ -242,9 +269,9 @@ class ModelTesterMixin:
second_filtered = second_flat[mask]
max_diff = torch.abs(first_filtered - second_filtered).max().item()
assert (
max_diff <= expected_max_diff
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
assert max_diff <= expected_max_diff, (
f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
)
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
@@ -259,9 +286,9 @@ class ModelTesterMixin:
output = output.to_tuple()[0]
assert output is not None, "Model output is None"
assert (
output[0].shape == expected_output_shape or self.output_shape
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
@@ -302,6 +329,71 @@ class ModelTesterMixin:
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
# save some things to test
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert cap_logger.out == ""
logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert cap_logger.out == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with pytest.raises(AttributeError) as error:
model.does_not_exist
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
@@ -324,12 +416,12 @@ class ModelTesterMixin:
assert param.data.dtype == torch_dtype
with torch.no_grad():
output = model(**get_dummy_inputs())
output_loaded = model_loaded(**get_dummy_inputs())
output = model(**self.get_dummy_inputs())
output_loaded = model_loaded(**self.get_dummy_inputs())
assert torch.allclose(
output, output_loaded, atol=1e-4
), f"Loaded model output differs for {torch_dtype}"
assert torch.allclose(output, output_loaded, atol=1e-4), (
f"Loaded model output differs for {torch_dtype}"
)
@require_accelerator
def test_sharded_checkpoints(self):
@@ -350,9 +442,9 @@ class ModelTesterMixin:
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
@@ -361,9 +453,9 @@ class ModelTesterMixin:
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after sharded save/load"
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after sharded save/load"
)
@require_accelerator
def test_sharded_checkpoints_with_variant(self):
@@ -382,16 +474,16 @@ class ModelTesterMixin:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(
os.path.join(tmp_dir, index_filename)
), f"Variant index file {index_filename} should exist"
assert os.path.exists(os.path.join(tmp_dir, index_filename)), (
f"Variant index file {index_filename} should exist"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
@@ -400,11 +492,10 @@ class ModelTesterMixin:
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after variant sharded save/load"
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after variant sharded save/load"
)
@require_accelerator
def test_sharded_checkpoints_with_parallel_loading(self):
import time
@@ -433,9 +524,9 @@ class ModelTesterMixin:
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
@@ -459,16 +550,14 @@ class ModelTesterMixin:
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
assert torch.allclose(
base_output[0], output_parallel[0], atol=1e-5
), "Output should match with parallel loading"
assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), (
"Output should match with parallel loading"
)
# Verify parallel loading is faster or at least not significantly slower
# For small test models, the difference might be negligible or even slightly slower due to overhead
# so we just check that parallel loading completed successfully and outputs match
assert (
parallel_load_time < sequential_load_time
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
assert parallel_load_time < sequential_load_time, (
f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
@@ -506,6 +595,6 @@ class ModelTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with model parallelism"
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with model parallelism"
)

View File

@@ -18,7 +18,7 @@ import uuid
import pytest
import torch
from huggingface_hub.utils import is_jinja_available
from huggingface_hub.utils import ModelCard, delete_repo, is_jinja_available
from ...others.test_utils import TOKEN, USER, is_staging_test
@@ -58,9 +58,9 @@ class ModelPushToHubTesterMixin:
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
assert torch.equal(p1, p2), (
"Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
)
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
@@ -84,9 +84,9 @@ class ModelPushToHubTesterMixin:
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
assert torch.equal(p1, p2), (
"Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
)
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
@@ -101,9 +101,9 @@ class ModelPushToHubTesterMixin:
model.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert (
model_card.library_name == "diffusers"
), f"Expected library_name 'diffusers', got {model_card.library_name}"
assert model_card.library_name == "diffusers", (
f"Expected library_name 'diffusers', got {model_card.library_name}"
)
# Reset repo
delete_repo(self.repo_id, token=TOKEN)

View File

@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import pytest
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor
@@ -61,7 +60,7 @@ def create_ip_adapter_state_dict(model):
return {"ip_adapter": ip_state_dict}
def check_if_ip_adapter_correctly_set(model) -> bool:
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
@@ -72,7 +71,7 @@ def check_if_ip_adapter_correctly_set(model) -> bool:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, IPAdapterAttnProcessor):
if isinstance(module, processor_cls):
return True
return False
@@ -93,48 +92,49 @@ class IPAdapterTesterMixin:
Use `pytest -m "not ip_adapter"` to skip these tests
"""
ip_adapter_processor_cls = None
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
self.prepare_model(model)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
# Create dummy IP adapter state dict
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
# Load IP adapter
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict_with_adapter = inputs_dict.copy()
inputs_dict_with_adapter["image_embeds"] = image_embeds
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter processors not set correctly"
)
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
), "Output should differ with IP Adapter enabled"
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
"Output should differ with IP Adapter enabled"
)
@pytest.mark.skip(
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
)
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# self.prepare_model(model)
# Create and load dummy IP adapter state dict
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
@@ -146,14 +146,16 @@ class IPAdapterTesterMixin:
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
), "Output should differ with different IP Adapter scales"
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
"Output should differ with different IP Adapter scales"
)
@pytest.mark.skip(
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
)
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
self.prepare_model(model)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
@@ -161,49 +163,16 @@ class IPAdapterTesterMixin:
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter should be unloaded"
)
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"
def test_ip_adapter_save_load(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
self.prepare_model(model)
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
torch.manual_seed(0)
output_before_save = model(**inputs_dict, return_dict=False)[0]
with tempfile.TemporaryDirectory() as tmpdir:
# Save the IP adapter weights
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
import safetensors.torch
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
# Unload and reload
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
# Reload from saved file
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
model._load_ip_adapter_weights([loaded_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
torch.manual_seed(0)
output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should match before and after save/load
assert torch.allclose(
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
), "Output should match before and after save/load"

View File

@@ -91,15 +91,15 @@ class LoraTesterMixin:
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
assert os.path.isfile(
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
), "LoRA weights file not created"
assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
@@ -119,12 +119,12 @@ class LoraTesterMixin:
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
assert torch.allclose(
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Outputs should match before and after save/load"
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Outputs should match before and after save/load"
)
def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig

View File

@@ -122,9 +122,9 @@ class CPUOffloadTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with CPU offloading"
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with CPU offloading"
)
@require_offload_support
def test_disk_offload_without_safetensors(self):
@@ -183,9 +183,9 @@ class CPUOffloadTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with disk offloading (safetensors)"
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with disk offloading (safetensors)"
)
@is_group_offload
@@ -247,18 +247,18 @@ class GroupOffloadTesterMixin:
)
output_with_group_offloading4 = run_forward(model)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading1, atol=1e-5
), "Output should match with block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading2, atol=1e-5
), "Output should match with non-blocking block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading3, atol=1e-5
), "Output should match with leaf-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading4, atol=1e-5
), "Output should match with leaf-level offloading with stream"
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), (
"Output should match with block-level offloading"
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), (
"Output should match with non-blocking block-level offloading"
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), (
"Output should match with leaf-level offloading"
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), (
"Output should match with leaf-level offloading with stream"
)
@require_group_offload_support
@torch.no_grad()
@@ -345,9 +345,9 @@ class GroupOffloadTesterMixin:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading, atol=atol
), "Output should match with disk-based group offloading"
assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), (
"Output should match with disk-based group offloading"
)
class LayerwiseCastingTesterMixin:
@@ -396,16 +396,16 @@ class LayerwiseCastingTesterMixin:
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert (
fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
), "Memory footprint should decrease with lower precision storage"
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
"Memory footprint should decrease with lower precision storage"
)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert (
fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
), "Peak memory should be lower with bf16 compute on newer GPUs"
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
"Peak memory should be lower with bf16 compute on newer GPUs"
)
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
@@ -415,6 +415,36 @@ class LayerwiseCastingTesterMixin:
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
model = self.model_class(**self.get_init_dict())
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
inputs_dict = self.get_inputs_dict()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
test_fn(torch.float16, torch.float32)
test_fn(torch.float8_e4m3fn, torch.float32)
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@is_memory
@require_accelerator

View File

@@ -26,6 +26,7 @@ from diffusers.utils.import_utils import (
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -128,9 +129,9 @@ class QuantizationTesterMixin:
model_quantized = self._create_quantized_model(config_kwargs)
num_params_quantized = model_quantized.num_parameters()
assert (
num_params == num_params_quantized
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
assert num_params == num_params_quantized, (
f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
)
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
model = self._load_unquantized_model()
@@ -140,9 +141,9 @@ class QuantizationTesterMixin:
mem_quantized = model_quantized.get_memory_footprint()
ratio = mem / mem_quantized
assert (
ratio >= expected_memory_reduction
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
assert ratio >= expected_memory_reduction, (
f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
)
def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
@@ -243,12 +244,12 @@ class QuantizationTesterMixin:
self._verify_if_layer_quantized(name, module, config_kwargs)
num_quantized_layers += 1
assert (
num_quantized_layers > 0
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
assert (
num_quantized_layers == expected_quantized_layers
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
assert num_quantized_layers > 0, (
f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
)
assert num_quantized_layers == expected_quantized_layers, (
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
@@ -272,9 +273,9 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(
module
), f"Module {name} should not be quantized but was found to be quantized"
assert not self._is_module_quantized(module), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
@@ -296,9 +297,9 @@ class QuantizationTesterMixin:
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
assert (
mem_with_exclusion > mem_fully_quantized
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
assert mem_with_exclusion > mem_fully_quantized, (
f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
)
def _test_quantization_device_map(self, config_kwargs):
"""
@@ -380,9 +381,9 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin):
def _verify_if_layer_quantized(self, name, module, config_kwargs):
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
assert (
module.weight.__class__ == expected_weight_class
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
assert module.weight.__class__ == expected_weight_class, (
f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
)
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_num_parameters(self, config_name):
@@ -450,13 +451,13 @@ class BitsAndBytesTesterMixin(QuantizationTesterMixin):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert (
module.weight.dtype == torch.float32
), f"Module {name} should be FP32 but is {module.weight.dtype}"
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert (
module.weight.dtype == torch.uint8
), f"Module {name} should be uint8 but is {module.weight.dtype}"
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
with torch.no_grad():
inputs = self.get_dummy_inputs()

View File

@@ -192,9 +192,9 @@ class SingleFileTesterMixin:
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
assert model.config[param_name] == param_value, (
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self):
single_file_kwargs = {}

View File

@@ -116,8 +116,7 @@ class TrainingTesterMixin:
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
f"do not match expected set {expected_set}"
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
@@ -169,9 +168,9 @@ class TrainingTesterMixin:
loss_2.backward()
# compare the output and parameters gradients
assert (
loss - loss_2
).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
assert (loss - loss_2).abs() < loss_tolerance, (
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
@@ -184,9 +183,9 @@ class TrainingTesterMixin:
if param.grad is None:
continue
assert torch_all_close(
param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
), f"Gradient mismatch for {name}"
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
f"Gradient mismatch for {name}"
)
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()

View File

@@ -13,119 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Any
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
class FluxTransformerTesterConfig:
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
pretrained_model_kwargs = {"subfolder": "transformer"}
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -137,11 +68,40 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
"axes_dims_rope": [4, 4, 8],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator
),
"pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
"img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
"txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -162,63 +122,223 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
pass
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
pass
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}

View File

@@ -1,330 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BitsAndBytesTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class FluxTransformerTesterConfig:
model_class = FluxTransformer2DModel
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
pretrained_model_kwargs = {"subfolder": "transformer"}
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator
),
"pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
"img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
"txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "image_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"does not match the output with them as 2d inputs"
)
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
pass
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
def prepare_model(self, model):
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
model.set_attn_processor(FluxIPAdapterAttnProcessor(hidden_size, joint_attention_dim, scale=1.0))
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}

View File

@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(
output_native, output_cuda, 1e-2
), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly

View File

@@ -406,6 +406,15 @@ def is_modelopt(test_case):
return pytest.mark.modelopt(test_case)
def is_context_parallel(test_case):
"""
Decorator marking a test as a context parallel inference test. These tests can be filtered using
`pytest -m "not context_parallel"` to skip them, or `pytest -m context_parallel` to run only them.
"""
return pytest.mark.context_parallel(test_case)
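# Illustrative usage (editor's sketch, not part of this commit): the decorator applies the
# `context_parallel` pytest marker to a test, assuming the marker has been registered via
# `pytest_configure`, so the suite can be narrowed or widened from the command line:
#
#     @is_context_parallel
#     def test_context_parallel_inference(self):
#         ...
#
#     pytest -m context_parallel          # run only context parallel tests
#     pytest -m "not context_parallel"    # skip them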
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.