mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
up
This commit is contained in:
@@ -21,7 +21,7 @@ import torch
|
||||
from diffusers import ZImageTransformer2DModel
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
@@ -41,8 +41,7 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def prepare_dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
height = width = embedding_dim = 16
|
||||
@@ -56,6 +55,10 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
@@ -115,3 +118,22 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Group offloading needs to revisited for this model because of state population.")
|
||||
def test_group_offloading_with_disk(self):
|
||||
super().test_group_offloading_with_disk()
|
||||
|
||||
|
||||
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = ZImageTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
@unittest.skip("Fullgraph is broken")
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
|
||||
@unittest.skip("Fullgraph AoT is broken")
|
||||
def test_compile_works_with_aot(self):
|
||||
super().test_compile_works_with_aot()
|
||||
|
||||
Reference in New Issue
Block a user