From 8ee24fcdaaf23e8a62b4e2d749d6c15dadae10e0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 29 Nov 2025 08:37:08 +0530 Subject: [PATCH] up --- .../test_models_transformer_z_image.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py index 35af2c3bfb..cae1173a72 100644 --- a/tests/models/transformers/test_models_transformer_z_image.py +++ b/tests/models/transformers/test_models_transformer_z_image.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import os import unittest @@ -87,6 +88,25 @@ class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase): inputs_dict = self.dummy_input return init_dict, inputs_dict + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + def test_gradient_checkpointing_is_applied(self): expected_set = {"ZImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set)