mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix generator 2
This commit is contained in:
@@ -22,6 +22,7 @@ import os
|
||||
from distutils.util import strtobool
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
@@ -35,7 +36,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
def get_random_generator(seed):
|
||||
seed = 1234
|
||||
random.seed(seed)
|
||||
os.environ[‘PYTHONHASHSEED’] = str(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
@@ -176,6 +177,7 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 3, 256, 256)
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
import ipdb; ipdb.set_trace()
|
||||
assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3
|
||||
|
||||
def test_sample_fast(self):
|
||||
@@ -216,6 +218,7 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 3, 256, 256)
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
import ipdb; ipdb.set_trace()
|
||||
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user