From a1f36ee3ef4ae1bf98bd260e539197259aa981c1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 3 Dec 2025 22:05:46 +0800 Subject: [PATCH] [Z-Image] various small changes, Z-Image transformer tests, etc. (#12741) * start zimage model tests. * up * up * up * up * up * up * up * up * up * up * up * up * Revert "up" This reverts commit bca3e27c96b942db49ccab8ddf824e7a54d43ed1. * expand upon compilation failure reason. * Update tests/models/transformers/test_models_transformer_z_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * reinitialize the padding tokens to ones to prevent NaN problems. * updates * up * skipping ZImage DiT tests * up * up --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../transformers/transformer_z_image.py | 57 ++---- .../pipelines/z_image/pipeline_z_image.py | 4 +- tests/lora/test_lora_layers_z_image.py | 146 +++++++++++++-- tests/models/test_modeling_common.py | 111 ++++++++++-- .../test_models_transformer_z_image.py | 171 ++++++++++++++++++ tests/pipelines/z_image/test_z_image.py | 15 +- 6 files changed, 424 insertions(+), 80 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_z_image.py diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 097672e0f7..1459e5974e 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -27,6 +27,7 @@ from ...models.modeling_utils import ModelMixin from ...models.normalization import RMSNorm from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput ADALN_EMBED_DIM = 256 @@ -39,17 +40,9 @@ class TimestepEmbedder(nn.Module): if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( - nn.Linear( - frequency_embedding_size, - mid_size, - bias=True, - ), + nn.Linear(frequency_embedding_size, mid_size, bias=True), nn.SiLU(), - nn.Linear( - mid_size, - out_size, - bias=True, - ), + nn.Linear(mid_size, out_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @@ -211,9 +204,7 @@ class ZImageTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), - ) + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) def forward( self, @@ -230,33 +221,19 @@ class ZImageTransformerBlock(nn.Module): # Attention block attn_out = self.attention( - self.attention_norm1(x) * scale_msa, - attention_mask=attn_mask, - freqs_cis=freqs_cis, + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis ) x = x + gate_msa * self.attention_norm2(attn_out) # FFN block - x = x + gate_mlp * self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x) * scale_mlp, - ) - ) + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) else: # Attention block - attn_out = self.attention( - self.attention_norm1(x), - attention_mask=attn_mask, - freqs_cis=freqs_cis, - ) + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) x = x + self.attention_norm2(attn_out) # FFN block - x = x + self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x), - ) - ) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) return x @@ -404,10 +381,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr ] ) self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), - ) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -494,11 +468,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr ) # padded feature - cap_padded_feat = torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, - ) - all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat) + cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) + all_cap_feats_out.append(cap_padded_feat) ### Process Image C, F, H, W = image.size() @@ -564,6 +535,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, + return_dict: bool = True, ): assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size @@ -672,4 +644,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr unified = list(unified.unbind(dim=0)) x = self.unpatchify(unified, x_size, patch_size, f_patch_size) - return x, {} + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py index ee7473d3cd..82bdd7d361 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py @@ -525,9 +525,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix latent_model_input_list = list(latent_model_input.unbind(dim=0)) model_out_list = self.transformer( - latent_model_input_list, - timestep_model_input, - prompt_embeds_model_input, + latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False )[0] if apply_cfg: diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index fcaf37b88c..35d1389d96 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -15,17 +15,13 @@ import sys import unittest +import numpy as np import torch from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model -from diffusers import ( - AutoencoderKL, - FlowMatchEulerDiscreteScheduler, - ZImagePipeline, - ZImageTransformer2DModel, -) +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel -from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend +from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device if is_peft_available(): @@ -34,13 +30,9 @@ if is_peft_available(): sys.path.append(".") -from .utils import PeftLoraLoaderMixinTests # noqa: E402 +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 -@unittest.skip( - "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations " - "and torch.empty padding tokens. LoRA functionality works correctly with real models." -) @require_peft_backend class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = ZImagePipeline @@ -127,6 +119,12 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id) transformer = self.transformer_cls(**self.transformer_kwargs) + # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`. + # This can cause NaN data values in our testing environment. Fixating them + # helps prevent that issue. + with torch.no_grad(): + transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data)) + transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data)) vae = self.vae_cls(**self.vae_kwargs) if scheduler_cls is None: @@ -161,3 +159,127 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } return pipeline_components, text_lora_config, denoiser_lora_config + + def test_correct_lora_configs_with_different_ranks(self): + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.delete_adapters("adapter-1") + + denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer + for name, _ in denoiser.named_modules(): + if "to_k" in name and "attention" in name and "lora" not in name: + module_name_to_rank_update = name.replace(".base_layer.", ".") + break + + # change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern + + self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue( + pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + @skip_mps + def test_lora_fuse_nan(self): + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + possible_tower_names = ["noise_refiner"] + filtered_tower_names = [ + tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) + ] + for tower_name in filtered_tower_names: + transformer_tower = getattr(pipe.transformer, tower_name) + transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + out = pipe(**inputs)[0] + + self.assertTrue(np.isnan(out).all()) + + def test_lora_scale_kwargs_match_fusion(self): + super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2) + + @unittest.skip("Needs to be debugged.") + def test_set_adapters_match_attention_kwargs(self): + super().test_set_adapters_match_attention_kwargs() + + @unittest.skip("Needs to be debugged.") + def test_simple_inference_with_text_denoiser_lora_and_scale(self): + super().test_simple_inference_with_text_denoiser_lora_and_scale() + + @unittest.skip("Not supported in ZImage.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in ZImage.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in ZImage.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in ZImage.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in ZImage.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in ZImage.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in ZImage.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in ZImage.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 034a1add63..ad5a6ba480 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -47,6 +47,7 @@ from diffusers.models.attention_processor import ( XFormersAttnProcessor, ) from diffusers.models.auto_model import AutoModel +from diffusers.models.modeling_outputs import BaseOutput from diffusers.training_utils import EMAModel from diffusers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -108,6 +109,11 @@ def check_if_lora_correctly_set(model) -> bool: return False +def normalize_output(out): + out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out + return torch.stack(out0) if isinstance(out0, list) else out0 + + # Will be run via run_test_in_subprocess def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): error = None @@ -536,6 +542,9 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] + image = normalize_output(image) + new_image = normalize_output(new_image) + max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -780,6 +789,9 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] + image = normalize_output(image) + new_image = normalize_output(new_image) + max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -842,6 +854,9 @@ class ModelTesterMixin: if isinstance(second, dict): second = second.to_tuple()[0] + first = normalize_output(first) + second = normalize_output(second) + out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)] @@ -860,11 +875,15 @@ class ModelTesterMixin: if isinstance(output, dict): output = output.to_tuple()[0] + if isinstance(output, list): + output = torch.stack(output) self.assertIsNotNone(output) # input & output have to have the same shape input_tensor = inputs_dict[self.main_input_name] + if isinstance(input_tensor, list): + input_tensor = torch.stack(input_tensor) if expected_output_shape is None: expected_shape = input_tensor.shape @@ -898,11 +917,15 @@ class ModelTesterMixin: if isinstance(output_1, dict): output_1 = output_1.to_tuple()[0] + if isinstance(output_1, list): + output_1 = torch.stack(output_1) output_2 = new_model(**inputs_dict) if isinstance(output_2, dict): output_2 = output_2.to_tuple()[0] + if isinstance(output_2, list): + output_2 = torch.stack(output_2) self.assertEqual(output_1.shape, output_2.shape) @@ -1138,6 +1161,8 @@ class ModelTesterMixin: torch.manual_seed(0) output_no_lora = model(**inputs_dict, return_dict=False)[0] + if isinstance(output_no_lora, list): + output_no_lora = torch.stack(output_no_lora) denoiser_lora_config = LoraConfig( r=rank, @@ -1151,6 +1176,8 @@ class ModelTesterMixin: torch.manual_seed(0) outputs_with_lora = model(**inputs_dict, return_dict=False)[0] + if isinstance(outputs_with_lora, list): + outputs_with_lora = torch.stack(outputs_with_lora) self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) @@ -1175,6 +1202,8 @@ class ModelTesterMixin: torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] + if isinstance(outputs_with_lora_2, list): + outputs_with_lora_2 = torch.stack(outputs_with_lora_2) self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) @@ -1296,31 +1325,35 @@ class ModelTesterMixin: def test_cpu_offload(self): if self.model_class._no_split_modules is None: pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() - model = model.to(torch_device) torch.manual_seed(0) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_sizes(model)[""] - # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] + with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) for max_size in max_gpu_sizes: max_memory = {0: max_size, "cpu": model_size * 2} new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) new_output = new_model(**inputs_dict) + new_normalized_output = normalize_output(new_output) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_accelerator def test_disk_offload_without_safetensors(self): @@ -1333,6 +1366,7 @@ class ModelTesterMixin: torch.manual_seed(0) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_sizes(model)[""] max_size = int(self.model_split_percents[0] * model_size) @@ -1352,8 +1386,8 @@ class ModelTesterMixin: self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) - - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + new_normalized_output = normalize_output(new_output) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_accelerator def test_disk_offload_with_safetensors(self): @@ -1366,6 +1400,7 @@ class ModelTesterMixin: torch.manual_seed(0) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_sizes(model)[""] with tempfile.TemporaryDirectory() as tmp_dir: @@ -1380,8 +1415,9 @@ class ModelTesterMixin: self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) + new_normalized_output = normalize_output(new_output) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_multi_accelerator def test_model_parallelism(self): @@ -1422,6 +1458,7 @@ class ModelTesterMixin: model = model.to(torch_device) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. @@ -1443,8 +1480,9 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) + new_normalized_output = normalize_output(new_output) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_accelerator def test_sharded_checkpoints_with_variant(self): @@ -1454,6 +1492,7 @@ class ModelTesterMixin: model = model.to(torch_device) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. @@ -1481,8 +1520,9 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) + new_normalized_output = normalize_output(new_output) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_accelerator def test_sharded_checkpoints_with_parallel_loading(self): @@ -1492,6 +1532,7 @@ class ModelTesterMixin: model = model.to(torch_device) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. @@ -1515,7 +1556,9 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + new_normalized_output = normalize_output(new_output) + + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) # set to no. os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no" @@ -1529,6 +1572,7 @@ class ModelTesterMixin: torch.manual_seed(0) base_output = model(**inputs_dict) + base_normalized_output = normalize_output(base_output) model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small. @@ -1549,7 +1593,9 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + new_normalized_output = normalize_output(new_output) + + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) # This test is okay without a GPU because we're not running any execution. We're just serializing # and check if the resultant files are following an expected format. @@ -1629,7 +1675,9 @@ class ModelTesterMixin: model = self.model_class(**config) model.eval() model.to(torch_device) - base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy() + base_slice = model(**inputs_dict)[0] + base_slice = normalize_output(base_slice) + base_slice = base_slice.detach().flatten().cpu().numpy() def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN @@ -1655,7 +1703,9 @@ class ModelTesterMixin: model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(model, storage_dtype, compute_dtype) - output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() + output = model(**inputs_dict)[0] + output = normalize_output(output) + output = output.float().flatten().detach().cpu().numpy() # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. # We just want to make sure that the layerwise casting is working as expected. @@ -1716,6 +1766,12 @@ class ModelTesterMixin: @parameterized.expand([False, True]) @require_torch_accelerator def test_group_offloading(self, record_stream): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading" in cls.__dict__ and cls is not ModelTesterMixin: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + pytest.skip("Model does not support group offloading.") + if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") @@ -1738,21 +1794,25 @@ class ModelTesterMixin: model.to(torch_device) output_without_group_offloading = run_forward(model) + output_without_group_offloading = normalize_output(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading1 = run_forward(model) + output_with_group_offloading1 = normalize_output(output_with_group_offloading1) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) output_with_group_offloading2 = run_forward(model) + output_with_group_offloading2 = normalize_output(output_with_group_offloading2) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="leaf_level") output_with_group_offloading3 = run_forward(model) + output_with_group_offloading3 = normalize_output(output_with_group_offloading3) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1760,6 +1820,7 @@ class ModelTesterMixin: torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream ) output_with_group_offloading4 = run_forward(model) + output_with_group_offloading4 = normalize_output(output_with_group_offloading4) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) @@ -1799,6 +1860,12 @@ class ModelTesterMixin: @torch.no_grad() @torch.inference_mode() def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading_with_disk" in cls.__dict__ and cls is not ModelTesterMixin: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + pytest.skip("Model does not support group offloading with disk yet.") + if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") @@ -1821,6 +1888,7 @@ class ModelTesterMixin: model.eval() model.to(torch_device) output_without_group_offloading = _run_forward(model, inputs_dict) + output_without_group_offloading = normalize_output(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1856,6 +1924,7 @@ class ModelTesterMixin: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) + output_with_group_offloading = normalize_output(output_with_group_offloading) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)) def test_auto_model(self, expected_max_diff=5e-5): @@ -1889,10 +1958,17 @@ class ModelTesterMixin: output_original = model(**inputs_dict) output_auto = auto_model(**inputs_dict) - if isinstance(output_original, dict): - output_original = output_original.to_tuple()[0] - if isinstance(output_auto, dict): - output_auto = output_auto.to_tuple()[0] + if isinstance(output_original, dict): + output_original = output_original.to_tuple()[0] + if isinstance(output_auto, dict): + output_auto = output_auto.to_tuple()[0] + + if isinstance(output_original, list): + output_original = torch.stack(output_original) + if isinstance(output_auto, list): + output_auto = torch.stack(output_auto) + + output_original, output_auto = output_original.float(), output_auto.float() max_diff = (output_original - output_auto).abs().max().item() self.assertLessEqual( @@ -2083,6 +2159,8 @@ class TorchCompileTesterMixin: recompile_limit = 1 if self.model_class.__name__ == "UNet2DConditionModel": recompile_limit = 2 + elif self.model_class.__name__ == "ZImageTransformer2DModel": + recompile_limit = 3 with ( torch._inductor.utils.fresh_inductor_cache(), @@ -2184,7 +2262,6 @@ class LoraHotSwappingForModelTesterMixin: backend_empty_cache(torch_device) def get_lora_config(self, lora_rank, lora_alpha, target_modules): - # from diffusers test_models_unet_2d_condition.py from peft import LoraConfig lora_config = LoraConfig( diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py new file mode 100644 index 0000000000..79054019f2 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_z_image.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import unittest + +import torch + +from diffusers import ZImageTransformer2DModel + +from ...testing_utils import IS_GITHUB_ACTIONS, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +@unittest.skipIf( + IS_GITHUB_ACTIONS, + reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.", +) +class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = ZImageTransformer2DModel + main_input_name = "x" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.9, 0.9, 0.9] + + def prepare_dummy_input(self, height=16, width=16): + batch_size = 1 + num_channels = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)] + encoder_hidden_states = [ + torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size) + ] + timestep = torch.tensor([0.0]).to(torch_device) + + return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep} + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "all_patch_size": (2,), + "all_f_patch_size": (1,), + "in_channels": 16, + "dim": 16, + "n_layers": 1, + "n_refiner_layers": 1, + "n_heads": 1, + "n_kv_heads": 2, + "qk_norm": True, + "cap_feat_dim": 16, + "rope_theta": 256.0, + "t_scale": 1000.0, + "axes_dims": [8, 4, 4], + "axes_lens": [256, 32, 32], + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ZImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_training(self): + super().test_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_ema_training(self): + super().test_ema_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing() + + @unittest.skip( + "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices." + ) + def test_layerwise_casting_training(self): + super().test_layerwise_casting_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") + def test_group_offloading(self): + super().test_group_offloading() + + @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") + def test_group_offloading_with_disk(self): + super().test_group_offloading_with_disk() + + +class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = ZImageTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return ZImageTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return ZImageTransformerTests().prepare_dummy_input(height=height, width=width) + + @unittest.skip( + "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice." + ) + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() + + @unittest.skip("Fullgraph AoT is broken") + def test_compile_works_with_aot(self): + super().test_compile_works_with_aot() + + @unittest.skip("Fullgraph is broken") + def test_compile_on_different_shapes(self): + super().test_compile_on_different_shapes() diff --git a/tests/pipelines/z_image/test_z_image.py b/tests/pipelines/z_image/test_z_image.py index 709473b0db..79a5fa0de5 100644 --- a/tests/pipelines/z_image/test_z_image.py +++ b/tests/pipelines/z_image/test_z_image.py @@ -20,12 +20,7 @@ import numpy as np import torch from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model -from diffusers import ( - AutoencoderKL, - FlowMatchEulerDiscreteScheduler, - ZImagePipeline, - ZImageTransformer2DModel, -) +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel from ...testing_utils import torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -106,6 +101,12 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): axes_dims=[8, 4, 4], axes_lens=[256, 32, 32], ) + # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`. + # This can cause NaN data values in our testing environment. Fixating them + # helps prevent that issue. + with torch.no_grad(): + transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data)) + transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data)) torch.manual_seed(0) vae = AutoencoderKL( @@ -183,7 +184,7 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): self.assertEqual(generated_image.shape, (3, 32, 32)) # fmt: off - expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732]) + expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453]) # fmt: on generated_slice = generated_image.flatten()