diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 4f2d56ea8f..f8fe2d1db8 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -27,6 +27,7 @@ from ...models.modeling_utils import ModelMixin from ...models.normalization import RMSNorm from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput ADALN_EMBED_DIM = 256 @@ -39,17 +40,9 @@ class TimestepEmbedder(nn.Module): if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( - nn.Linear( - frequency_embedding_size, - mid_size, - bias=True, - ), + nn.Linear(frequency_embedding_size, mid_size, bias=True), nn.SiLU(), - nn.Linear( - mid_size, - out_size, - bias=True, - ), + nn.Linear(mid_size, out_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @@ -211,9 +204,7 @@ class ZImageTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), - ) + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) def forward( self, @@ -230,33 +221,19 @@ class ZImageTransformerBlock(nn.Module): # Attention block attn_out = self.attention( - self.attention_norm1(x) * scale_msa, - attention_mask=attn_mask, - freqs_cis=freqs_cis, + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis ) x = x + gate_msa * self.attention_norm2(attn_out) # FFN block - x = x + gate_mlp * self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x) * scale_mlp, - ) - ) + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) else: # Attention block - attn_out = self.attention( - self.attention_norm1(x), - attention_mask=attn_mask, - freqs_cis=freqs_cis, - ) + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) x = x + self.attention_norm2(attn_out) # FFN block - x = x + self.ffn_norm2( - self.feed_forward( - self.ffn_norm1(x), - ) - ) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) return x @@ -404,10 +381,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr ] ) self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), - ) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -492,10 +466,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr ) ) # padded feature - cap_padded_feat = torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, - ) + cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) all_cap_feats_out.append(cap_padded_feat) ### Process Image @@ -557,6 +528,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, + return_dict: bool = True, ): assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size @@ -658,4 +630,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr unified = list(unified.unbind(dim=0)) x = self.unpatchify(unified, x_size, patch_size, f_patch_size) - return x, {} + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py index a4fcacb6eb..1e4fadd753 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py @@ -525,9 +525,7 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): latent_model_input_list = list(latent_model_input.unbind(dim=0)) model_out_list = self.transformer( - latent_model_input_list, - timestep_model_input, - prompt_embeds_model_input, + latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False )[0] if apply_cfg: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6f4c3d544b..475824a855 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -536,6 +536,11 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] + if isinstance(image, list): + image = torch.stack(image) + if isinstance(new_image, list): + new_image = torch.stack(new_image) + max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -780,6 +785,11 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] + if isinstance(image, list): + image = torch.stack(image) + if isinstance(new_image, list): + new_image = torch.stack(new_image) + max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -842,6 +852,11 @@ class ModelTesterMixin: if isinstance(second, dict): second = second.to_tuple()[0] + if isinstance(first, list): + first = torch.stack(first) + if isinstance(second, list): + second = torch.stack(second) + out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)] @@ -860,11 +875,15 @@ class ModelTesterMixin: if isinstance(output, dict): output = output.to_tuple()[0] + if isinstance(output, list): + output = torch.stack(output) self.assertIsNotNone(output) # input & output have to have the same shape input_tensor = inputs_dict[self.main_input_name] + if isinstance(input_tensor, list): + input_tensor = torch.stack(input_tensor) if expected_output_shape is None: expected_shape = input_tensor.shape @@ -898,11 +917,15 @@ class ModelTesterMixin: if isinstance(output_1, dict): output_1 = output_1.to_tuple()[0] + if isinstance(output_1, list): + output_1 = torch.stack(output_1) output_2 = new_model(**inputs_dict) if isinstance(output_2, dict): output_2 = output_2.to_tuple()[0] + if isinstance(output_2, list): + output_2 = torch.stack(output_2) self.assertEqual(output_1.shape, output_2.shape) @@ -1138,6 +1161,8 @@ class ModelTesterMixin: torch.manual_seed(0) output_no_lora = model(**inputs_dict, return_dict=False)[0] + if isinstance(output_no_lora, list): + output_no_lora = torch.stack(output_no_lora) denoiser_lora_config = LoraConfig( r=rank, @@ -1151,6 +1176,8 @@ class ModelTesterMixin: torch.manual_seed(0) outputs_with_lora = model(**inputs_dict, return_dict=False)[0] + if isinstance(outputs_with_lora, list): + outputs_with_lora = torch.stack(outputs_with_lora) self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) @@ -1175,6 +1202,8 @@ class ModelTesterMixin: torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] + if isinstance(outputs_with_lora_2, list): + outputs_with_lora_2 = torch.stack(outputs_with_lora_2) self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) @@ -1307,6 +1336,7 @@ class ModelTesterMixin: model_size = compute_module_sizes(model)[""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] + print(f"{max_gpu_sizes=}") with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) @@ -1314,13 +1344,19 @@ class ModelTesterMixin: max_memory = {0: max_size, "cpu": model_size * 2} new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded + print(f"{max_size=} {new_model.hf_device_map.values()=}") self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) @require_torch_accelerator def test_disk_offload_without_safetensors(self): @@ -1353,7 +1389,12 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) @require_torch_accelerator def test_disk_offload_with_safetensors(self): @@ -1381,7 +1422,12 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) @require_torch_multi_accelerator def test_model_parallelism(self): @@ -1444,7 +1490,12 @@ class ModelTesterMixin: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) @require_torch_accelerator def test_sharded_checkpoints_with_variant(self): @@ -1482,7 +1533,12 @@ class ModelTesterMixin: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) @require_torch_accelerator def test_sharded_checkpoints_with_parallel_loading(self): @@ -1515,7 +1571,13 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) # set to no. os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no" @@ -1549,7 +1611,13 @@ class ModelTesterMixin: if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + + if isinstance(base_output[0], list): + base_output = torch.stack(base_output[0]) + if isinstance(new_output[0], list): + new_output = torch.stack(new_output[0]) + + self.assertTrue(torch.allclose(base_output, new_output, atol=1e-5)) # This test is okay without a GPU because we're not running any execution. We're just serializing # and check if the resultant files are following an expected format. @@ -1629,7 +1697,10 @@ class ModelTesterMixin: model = self.model_class(**config) model.eval() model.to(torch_device) - base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy() + base_slice = model(**inputs_dict)[0] + if isinstance(base_slice, list): + base_slice = torch.stack(base_slice) + base_slice = base_slice.detach().flatten().cpu().numpy() def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN @@ -1655,7 +1726,10 @@ class ModelTesterMixin: model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(model, storage_dtype, compute_dtype) - output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() + output = model(**inputs_dict)[0] + if isinstance(output, list): + output = torch.stack(output) + output = output.float().flatten().detach().cpu().numpy() # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. # We just want to make sure that the layerwise casting is working as expected. @@ -1716,6 +1790,12 @@ class ModelTesterMixin: @parameterized.expand([False, True]) @require_torch_accelerator def test_group_offloading(self, record_stream): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading" in cls.__dict__ and cls is not ModelTesterMixin: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + pytest.skip("Model does not support group offloading.") + if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") @@ -1738,21 +1818,29 @@ class ModelTesterMixin: model.to(torch_device) output_without_group_offloading = run_forward(model) + if isinstance(output_without_group_offloading, list): + output_without_group_offloading = torch.stack(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading1 = run_forward(model) + if isinstance(output_with_group_offloading1, list): + output_with_group_offloading1 = torch.stack(output_with_group_offloading1) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) output_with_group_offloading2 = run_forward(model) + if isinstance(output_with_group_offloading2, list): + output_with_group_offloading2 = torch.stack(output_with_group_offloading2) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="leaf_level") output_with_group_offloading3 = run_forward(model) + if isinstance(output_with_group_offloading3, list): + output_with_group_offloading3 = torch.stack(output_with_group_offloading3) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1760,6 +1848,8 @@ class ModelTesterMixin: torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream ) output_with_group_offloading4 = run_forward(model) + if isinstance(output_with_group_offloading4, list): + output_with_group_offloading4 = torch.stack(output_with_group_offloading4) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) @@ -1814,6 +1904,12 @@ class ModelTesterMixin: torch.manual_seed(0) return model(**inputs_dict)[0] + for cls in inspect.getmro(self.__class__): + if "test_group_offloading_with_disk" in cls.__dict__ and cls is not ModelTesterMixin: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + pytest.skip("Model does not support group offloading with disk.") + if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level": pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.") @@ -1824,6 +1920,8 @@ class ModelTesterMixin: model.eval() model.to(torch_device) output_without_group_offloading = _run_forward(model, inputs_dict) + if isinstance(output_without_group_offloading, list): + output_without_group_offloading = torch.stack(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1859,6 +1957,8 @@ class ModelTesterMixin: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) + if isinstance(output_with_group_offloading, list): + output_with_group_offloading = torch.stack(output_with_group_offloading) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)) def test_auto_model(self, expected_max_diff=5e-5): @@ -1892,10 +1992,17 @@ class ModelTesterMixin: output_original = model(**inputs_dict) output_auto = auto_model(**inputs_dict) - if isinstance(output_original, dict): - output_original = output_original.to_tuple()[0] - if isinstance(output_auto, dict): - output_auto = output_auto.to_tuple()[0] + if isinstance(output_original, dict): + output_original = output_original.to_tuple()[0] + if isinstance(output_auto, dict): + output_auto = output_auto.to_tuple()[0] + + if isinstance(output_original, list): + output_original = torch.stack(output_original) + if isinstance(output_auto, list): + output_auto = torch.stack(output_auto) + + output_original, output_auto = output_original.float(), output_auto.float() max_diff = (output_original - output_auto).abs().max().item() self.assertLessEqual( diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py new file mode 100644 index 0000000000..61687977e1 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_z_image.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import torch + +from diffusers import ZImageTransformer2DModel + +from ...testing_utils import torch_device +from ..test_modeling_common import ModelTesterMixin + + +# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = ZImageTransformer2DModel + main_input_name = "x" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.8, 0.8, 0.9] + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 16 + height = width = embedding_dim = 16 + sequence_length = 16 + + hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)] + encoder_hidden_states = [ + torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size) + ] + timestep = torch.tensor([0.0]).to(torch_device) + + return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "all_patch_size": (2,), + "all_f_patch_size": (1,), + "in_channels": 16, + "dim": 32, + "n_layers": 2, + "n_refiner_layers": 1, + "n_heads": 2, + "n_kv_heads": 2, + "qk_norm": True, + "cap_feat_dim": 16, + "rope_theta": 256.0, + "t_scale": 1000.0, + "axes_dims": [8, 4, 4], + "axes_lens": [256, 32, 32], + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ZImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_training(self): + super().test_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_ema_training(self): + super().test_ema_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing() + + @unittest.skip("Test needs to be revisited.") + def test_layerwise_casting_training(self): + super().test_layerwise_casting_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skip("Group offloading needs to revisited for this model because of state population.") + def test_group_offloading(self): + super().test_group_offloading() + + @unittest.skip("Group offloading needs to revisited for this model because of state population.") + def test_group_offloading_with_disk(self): + super().test_group_offloading_with_disk() diff --git a/tests/pipelines/z_image/test_z_image.py b/tests/pipelines/z_image/test_z_image.py index 709473b0db..ab2206311d 100644 --- a/tests/pipelines/z_image/test_z_image.py +++ b/tests/pipelines/z_image/test_z_image.py @@ -27,7 +27,7 @@ from diffusers import ( ZImageTransformer2DModel, ) -from ...testing_utils import torch_device +from ...testing_utils import is_flaky, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np @@ -169,6 +169,7 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): return inputs + @is_flaky(max_attempts=10) def test_inference(self): device = "cpu"