1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2025-10-17 10:43:52 +05:30
parent 1ac55e7a7e
commit 74160ed00f
8 changed files with 18 additions and 43 deletions

View File

@@ -36,12 +36,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AsymmetricAutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2

View File

@@ -18,12 +18,13 @@ from diffusers import AutoencoderKLCosmos
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLCosmosTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, unittest.TestCase):
@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -23,12 +23,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, unittest.TestCase):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -23,12 +23,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -19,12 +19,13 @@ from diffusers import AutoencoderKLMagvit
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMagvitTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2

View File

@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderKLMochi
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMochiTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi
main_input_name = "sample"
base_precision = 1e-2
@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, unittest.TestCase):
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
"""
pass
@unittest.skip("Unsupported test.")
def test_model_parallelism(self):
"""

View File

@@ -31,12 +31,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderOobleckTests(ModelTesterMixin, unittest.TestCase):
class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, unittest.TestCase):
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return

View File

@@ -19,19 +19,15 @@ import torch
from diffusers import VQModel
from ...testing_utils import (
backend_manual_seed,
enable_full_determinism,
floats_tensor,
torch_device,
)
from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class VQModelTests(ModelTesterMixin, unittest.TestCase):
class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel
main_input_name = "sample"