From 5da71f8fa3e1439437ae223118e6d81073baefed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Jun 2022 16:22:12 +0000 Subject: [PATCH] fix generator 2 --- tests/test_modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2a11e1c00f..7cbf5d272a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,6 +22,7 @@ import os from distutils.util import strtobool import torch +import numpy as np from diffusers import GaussianDDPMScheduler, UNetModel from diffusers.pipeline_utils import DiffusionPipeline @@ -35,7 +36,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu" def get_random_generator(seed): seed = 1234 random.seed(seed) - os.environ[‘PYTHONHASHSEED’] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -176,6 +177,7 @@ class SamplerTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) image_slice = image[0, -1, -3:, -3:].cpu() + import ipdb; ipdb.set_trace() assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3 def test_sample_fast(self): @@ -216,6 +218,7 @@ class SamplerTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) image_slice = image[0, -1, -3:, -3:].cpu() + import ipdb; ipdb.set_trace() assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3