1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-11-28 18:55:47 +05:30
parent 12608de5cb
commit 6d47d106ba

View File

@@ -21,7 +21,7 @@ import torch
from diffusers import ZImageTransformer2DModel
from ...testing_utils import torch_device
from ..test_modeling_common import ModelTesterMixin
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
@@ -41,8 +41,7 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.8, 0.9]
@property
def dummy_input(self):
def prepare_dummy_input(self):
batch_size = 1
num_channels = 16
height = width = embedding_dim = 16
@@ -56,6 +55,10 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (4, 32, 32)
@@ -115,3 +118,22 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
@unittest.skip("Group offloading needs to revisited for this model because of state population.")
def test_group_offloading_with_disk(self):
super().test_group_offloading_with_disk()
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = ZImageTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
@unittest.skip("Fullgraph is broken")
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
@unittest.skip("Fullgraph AoT is broken")
def test_compile_works_with_aot(self):
super().test_compile_works_with_aot()