diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 7cbf5d272a..deae3b0117 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -14,7 +14,6 @@ # limitations under the License. - import random import tempfile import unittest @@ -22,7 +21,6 @@ import os from distutils.util import strtobool import torch -import numpy as np from diffusers import GaussianDDPMScheduler, UNetModel from diffusers.pipeline_utils import DiffusionPipeline @@ -31,22 +29,7 @@ from models.vision.ddpm.modeling_ddpm import DDPM global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" - - -def get_random_generator(seed): - seed = 1234 - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.enabled = False - generator = torch.Generator() - return generator - +torch.backends.cuda.matmul.allow_tf32 = False def parse_flag_from_env(key, default=False): @@ -132,7 +115,7 @@ class SamplerTesterMixin(unittest.TestCase): @slow def test_sample(self): - generator = get_random_generator(0) + generator = torch.manual_seed(0) # 1. Load models scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") @@ -182,13 +165,12 @@ class SamplerTesterMixin(unittest.TestCase): def test_sample_fast(self): # 1. Load models - generator = get_random_generator(0) + generator = torch.manual_seed(0) scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise - torch.manual_seed(0) image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) # 3. Denoise @@ -218,8 +200,8 @@ class SamplerTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) image_slice = image[0, -1, -3:, -3:].cpu() - import ipdb; ipdb.set_trace() - assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3 + expected_slice = torch.tensor([-0.0304, -0.1895, -0.2436, -0.9837, -0.5422, 0.1931, -0.8175, 0.0862, -0.7783]) + assert (image_slice.flatten() - expected_slice).abs().sum() < 1e-3 class PipelineTesterMixin(unittest.TestCase):