From 6d47d106ba630e2fea4ece1be27fb7156119fef8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 28 Nov 2025 18:55:47 +0530 Subject: [PATCH] up --- .../test_models_transformer_z_image.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py index 61687977e1..adc1b85747 100644 --- a/tests/models/transformers/test_models_transformer_z_image.py +++ b/tests/models/transformers/test_models_transformer_z_image.py @@ -21,7 +21,7 @@ import torch from diffusers import ZImageTransformer2DModel from ...testing_utils import torch_device -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin # Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations @@ -41,8 +41,7 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): # We override the items here because the transformer under consideration is small. model_split_percents = [0.8, 0.8, 0.9] - @property - def dummy_input(self): + def prepare_dummy_input(self): batch_size = 1 num_channels = 16 height = width = embedding_dim = 16 @@ -56,6 +55,10 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep} + @property + def dummy_input(self): + return self.prepare_dummy_input() + @property def input_shape(self): return (4, 32, 32) @@ -115,3 +118,22 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): @unittest.skip("Group offloading needs to revisited for this model because of state population.") def test_group_offloading_with_disk(self): super().test_group_offloading_with_disk() + + +class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = ZImageTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def prepare_init_args_and_inputs_for_common(self): + return ZImageTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return ZImageTransformerTests().prepare_dummy_input(height=height, width=width) + + @unittest.skip("Fullgraph is broken") + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() + + @unittest.skip("Fullgraph AoT is broken") + def test_compile_works_with_aot(self): + super().test_compile_works_with_aot()