diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 09614171f1..2a11e1c00f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + + import random import tempfile import unittest @@ -30,6 +32,22 @@ global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" +def get_random_generator(seed): + seed = 1234 + random.seed(seed) + os.environ[‘PYTHONHASHSEED’] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + generator = torch.Generator() + return generator + + + def parse_flag_from_env(key, default=False): try: value = os.environ[key] @@ -113,8 +131,7 @@ class SamplerTesterMixin(unittest.TestCase): @slow def test_sample(self): - generator = torch.Generator() - generator = generator.manual_seed(6694729458485568) + generator = get_random_generator(0) # 1. Load models scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") @@ -163,8 +180,7 @@ class SamplerTesterMixin(unittest.TestCase): def test_sample_fast(self): # 1. Load models - generator = torch.Generator() - generator = generator.manual_seed(6694729458485568) + generator = get_random_generator(0) scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) @@ -214,16 +230,14 @@ class PipelineTesterMixin(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: ddpm.save_pretrained(tmpdirname) new_ddpm = DDPM.from_pretrained(tmpdirname) - - generator = torch.Generator() - generator = generator.manual_seed(669472945848556) + + generator = torch.manual_seed(0) image = ddpm(generator=generator) - generator = generator.manual_seed(669472945848556) + generator = generator.manual_seed(0) new_image = new_ddpm(generator=generator) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - @slow def test_from_pretrained_hub(self): @@ -235,12 +249,10 @@ class PipelineTesterMixin(unittest.TestCase): ddpm.noise_scheduler.num_timesteps = 10 ddpm_from_hub.noise_scheduler.num_timesteps = 10 - - generator = torch.Generator(device=torch_device) - generator = generator.manual_seed(669472945848556) + generator = torch.manual_seed(0) image = ddpm(generator=generator) - generator = generator.manual_seed(669472945848556) + generator = generator.manual_seed(0) new_image = ddpm_from_hub(generator=generator) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"