Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Match the generator device to the pipeline for DDPM and DDIM (#1222)
* Match the generator device to the pipeline for DDPM and DDIM
* style
* fix
* update values
* fix fast tests
* trigger slow tests
* deprecate
* last value fixes
* mps fixes
@@ -11,10 +11,12 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -28,6 +30,7 @@ from tqdm.auto import tqdm


 logger = get_logger(__name__)
+diffusers_version = version.parse(version.parse(__version__).base_version)


 def _extract_into_tensor(arr, timesteps, broadcast_shape):
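Note: parsing `base_version` strips pre-release and dev suffixes, so an installed `0.8.0.dev0` compares as `0.8.0` in the version gate below rather than sorting before it.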
@@ -406,7 +409,11 @@ def main(args):
                     scheduler=noise_scheduler,
                 )

-                generator = torch.manual_seed(0)
+                deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
+                if diffusers_version < version.parse("0.8.0"):
+                    generator = torch.manual_seed(0)
+                else:
+                    generator = torch.Generator(device=pipeline.device).manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
                 images = pipeline(
                     generator=generator,
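For readers following along, a minimal standalone sketch of the version-gated seeding added above (assuming a CUDA machine and the `google/ddpm-cifar10-32` checkpoint used in the tests below):

```python
import torch
from packaging import version
from diffusers import DDPMPipeline, __version__

pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")

diffusers_version = version.parse(version.parse(__version__).base_version)
if diffusers_version < version.parse("0.8.0"):
    # Older releases only consumed CPU generators.
    generator = torch.manual_seed(0)
else:
    # From 0.8.0 on, the generator should live on the pipeline's device.
    generator = torch.Generator(device=pipeline.device).manual_seed(0)

images = pipeline(generator=generator, output_type="numpy").images
```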
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 from typing import Optional, Tuple, Union

 import torch

 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDIMPipeline(DiffusionPipeline):
@@ -75,24 +75,29 @@ class DDIMPipeline(DiffusionPipeline):
                 generated images.
         """

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
-        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_kwargs = {}
-        if accepts_use_clipped_model_output:
-            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
-
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
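The initial noise is now drawn directly on the pipeline device (`device=self.device`) instead of being drawn on CPU and moved with `.to(self.device)`. With a device-matched generator the draw stays reproducible, but the two patterns produce different streams for the same seed, which is why the expected test values change below. A quick sketch, assuming a CUDA machine:

```python
import torch

shape = (1, 3, 32, 32)

# Old pattern: draw on CPU with the global/CPU generator, then move.
noise_old = torch.randn(shape, generator=torch.manual_seed(0)).to("cuda")

# New pattern: draw on the target device with a device-matched generator.
gen = torch.Generator(device="cuda").manual_seed(0)
noise_new = torch.randn(shape, generator=gen, device="cuda")

# Same seed, different RNG streams: the tensors (almost surely) differ.
print(torch.allclose(noise_old, noise_new))
```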
@@ -100,7 +105,9 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+            ).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
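A usage sketch of the updated contract (model id taken from the integration tests below): `eta`, `use_clipped_model_output`, and `generator` are now forwarded explicitly to `DDIMScheduler.step`, so a device-matched generator keeps the whole denoising loop reproducible.

```python
import torch
from diffusers import DDIMPipeline

ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")

# A CPU generator on a CUDA pipeline now hits the deprecation path above;
# match the device instead.
generator = torch.Generator(device="cuda").manual_seed(0)
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
```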
@@ -80,12 +80,25 @@ class DDPMPipeline(DiffusionPipeline):
             new_config["predict_epsilon"] = predict_epsilon
             self.scheduler._internal_dict = FrozenDict(new_config)

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
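Note: this mirrors the DDIM check above, so both pipelines raise the same deprecation warning when given a CPU generator on a non-CPU device (mps is exempt because mps generators were not supported by torch at the time).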
@@ -292,10 +292,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            device = model_output.device
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                variance_noise = variance_noise.to(device)
+            else:
+                variance_noise = torch.randn(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

         pred_prev_sample = pred_prev_sample + variance
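A standalone sketch of the mps branch above (`randn_reproducible` is a hypothetical helper name, not diffusers API; the CPU draw followed by `.to(device)` is the workaround for `torch.randn` not being reproducible on mps):

```python
import torch

def randn_reproducible(like: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Hypothetical helper mirroring the scheduler branch: on mps, draw the
    # noise on CPU with the generator, then move it to the target device.
    device = like.device
    if device.type == "mps":
        noise = torch.randn(like.shape, dtype=like.dtype, generator=generator)
        return noise.to(device)
    return torch.randn(like.shape, generator=generator, device=device, dtype=like.dtype)
```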
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -43,21 +43,18 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
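Note: pinning the fast test to `device = "cpu"` makes the expected slices backend-independent, which is why the mps warmup pass above and the looser mps tolerance in the next hunk can be dropped.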
@@ -67,13 +64,12 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2


 @slow
-@require_torch
+@require_torch_gpu
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"
@@ -85,13 +81,13 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_cifar10(self):
@@ -104,11 +100,11 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
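Note: the reference slices change because a device generator seeded with 0 draws from a different random stream than the previous CPU `torch.manual_seed(0)`; the values were regenerated accordingly (the "update values" and "last value fixes" entries in the commit message).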
@@ -20,7 +20,7 @@ import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,21 +44,18 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -68,9 +65,8 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -85,10 +81,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -100,7 +96,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):


 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -112,11 +108,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -68,30 +68,25 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
         return model

     def test_inference_superresolution(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
         vqvae = self.dummy_vq_model

         ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
-        ldm.to(torch_device)
+        ldm.to(device)
         ldm.set_progress_bar_config(disable=None)

-        init_image = self.dummy_image.to(torch_device)
+        init_image = self.dummy_image.to(device)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


 @slow
@@ -42,7 +42,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -93,11 +92,17 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator_2 = torch.Generator(device=torch_device).manual_seed(0)
         pipe_2 = pipe_2.to(torch_device)
+        generator_2 = generator.manual_seed(0)
         out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
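The mps fallback reflects a PyTorch limitation at the time: `torch.Generator(device="mps")` was not supported, so the tests drop back to seeding the global CPU generator there. A small sketch of the pattern (the helper name is illustrative only):

```python
import torch

def seeded_generator(device: str, seed: int = 0) -> torch.Generator:
    # device type MPS is not supported for the torch.Generator() api,
    # so fall back to seeding the global CPU generator.
    if device == "mps":
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)

generator = seeded_generator("cpu")
```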
@@ -107,13 +112,19 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
             pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

             assert np.max(np.abs(out - out_2)) < 1e-3
@@ -121,13 +132,19 @@ class DownloadTests(unittest.TestCase):
     def test_load_no_safety_checker_default_locally(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
             pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

             assert np.max(np.abs(out - out_2)) < 1e-3
@@ -431,7 +448,7 @@ class PipelineSlowTests(unittest.TestCase):
         new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
         new_ddpm.to(torch_device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -452,7 +469,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -475,7 +492,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -491,7 +508,7 @@ class PipelineSlowTests(unittest.TestCase):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         images = pipe(generator=generator, output_type="numpy").images
         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)
@@ -506,40 +523,8 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
@@ -554,12 +539,12 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(seed)
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=2,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,
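Putting the surviving batched test together, a condensed sketch (CUDA assumed; `use_clipped_model_output=True` is what makes DDIM at `eta=1.0` over the full 1000 steps match DDPM):

```python
import numpy as np
import torch
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel

unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler()).to("cuda")
ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler()).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

generator = torch.Generator(device="cuda").manual_seed(0)
ddim_images = ddim(
    batch_size=2,
    generator=generator,
    num_inference_steps=1000,
    eta=1.0,
    output_type="numpy",
    use_clipped_model_output=True,  # needed to make DDIM match DDPM
).images

# The values aren't exactly equal, but the images look the same visually.
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
```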