mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix random seed
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -30,6 +32,22 @@ global_rng = random.Random()
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def get_random_generator(seed):
|
||||
seed = 1234
|
||||
random.seed(seed)
|
||||
os.environ[‘PYTHONHASHSEED’] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.enabled = False
|
||||
generator = torch.Generator()
|
||||
return generator
|
||||
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
try:
|
||||
value = os.environ[key]
|
||||
@@ -113,8 +131,7 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_sample(self):
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(6694729458485568)
|
||||
generator = get_random_generator(0)
|
||||
|
||||
# 1. Load models
|
||||
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
|
||||
@@ -163,8 +180,7 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||
|
||||
def test_sample_fast(self):
|
||||
# 1. Load models
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(6694729458485568)
|
||||
generator = get_random_generator(0)
|
||||
|
||||
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
|
||||
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
|
||||
@@ -214,16 +230,14 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
new_ddpm = DDPM.from_pretrained(tmpdirname)
|
||||
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
image = ddpm(generator=generator)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
generator = generator.manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
@@ -235,12 +249,10 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
ddpm.noise_scheduler.num_timesteps = 10
|
||||
ddpm_from_hub.noise_scheduler.num_timesteps = 10
|
||||
|
||||
|
||||
generator = torch.Generator(device=torch_device)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
image = ddpm(generator=generator)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
generator = generator.manual_seed(0)
|
||||
new_image = ddpm_from_hub(generator=generator)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
Reference in New Issue
Block a user