1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix random seed

This commit is contained in:
Patrick von Platen
2022-06-07 18:20:14 +02:00
parent 9c4cd06df9
commit 46d20d2d76

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
@@ -30,6 +32,22 @@ global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def get_random_generator(seed):
seed = 1234
random.seed(seed)
os.environ[PYTHONHASHSEED] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
generator = torch.Generator()
return generator
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
@@ -113,8 +131,7 @@ class SamplerTesterMixin(unittest.TestCase):
@slow
def test_sample(self):
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
generator = get_random_generator(0)
# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
@@ -163,8 +180,7 @@ class SamplerTesterMixin(unittest.TestCase):
def test_sample_fast(self):
# 1. Load models
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
generator = get_random_generator(0)
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
@@ -214,16 +230,14 @@ class PipelineTesterMixin(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPM.from_pretrained(tmpdirname)
generator = torch.Generator()
generator = generator.manual_seed(669472945848556)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556)
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub(self):
@@ -235,12 +249,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.noise_scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10
generator = torch.Generator(device=torch_device)
generator = generator.manual_seed(669472945848556)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556)
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"