
[schedulers] handle dtype in add_noise (#767)

* handle dtype in vae and image2image pipeline

* handle dtype in add noise

* don't modify vae and pipeline

* remove the if
Author: Suraj Patil
Date: 2022-10-07 16:44:19 +02:00
Committed by: Patrick von Platen
Parent: 91ddd2a25b
Commit: 2bdde4dd83
4 changed files with 16 additions and 18 deletions
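
What motivates the change: PyTorch type promotion upcasts a float16 tensor as soon as it meets a float32 one, so a scheduler whose alphas_cumprod (or sigmas) buffer stays in float32 silently turns half-precision latents back into float32 inside add_noise. A minimal illustration of that promotion rule (not part of the commit itself):

    import torch

    half = torch.ones(1, dtype=torch.float16)
    single = torch.ones(1, dtype=torch.float32)

    # Mixed-precision arithmetic promotes to the wider dtype.
    print((half * single).dtype)  # torch.float32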

src/diffusers/schedulers/scheduling_ddim.py

@@ -300,11 +300,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
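
The trailing context above is the start of the standard forward-noising computation. A self-contained sketch of what add_noise does in the DDIM/DDPM/PNDM schedulers once the cast is in place (the function name and broadcasting details are illustrative, not copied from the file):

    import torch

    def add_noise_sketch(original_samples, noise, timesteps, alphas_cumprod):
        # Cast the lookup table to the samples' device and dtype, as in the diff,
        # so float16 inputs produce float16 outputs.
        alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # sqrt(alpha_bar_t), broadcast to the sample shape.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        # sqrt(1 - alpha_bar_t), broadcast the same way.
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise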

src/diffusers/schedulers/scheduling_ddpm.py

@@ -294,11 +294,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()

src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -257,9 +257,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
-        sigmas = self.sigmas.to(original_samples.device)
-        schedule_timesteps = self.timesteps.to(original_samples.device)
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        self.timesteps = self.timesteps.to(original_samples.device)
+        timesteps = timesteps.to(original_samples.device)
+
+        schedule_timesteps = self.timesteps
         if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
             deprecate(
                 "timesteps as indices",
@@ -273,7 +277,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-        sigma = sigmas[step_indices].flatten()
+        sigma = self.sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
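
The LMS scheduler noises samples with sigmas rather than alphas_cumprod, which is why this hunk also keeps the converted tensors on self. A sketch of that sigma-based path, following the hunk's lookup logic (the function name and the final formula are written out here for illustration):

    import torch

    def lms_add_noise_sketch(original_samples, noise, timesteps, sigmas, schedule_timesteps):
        # Cast sigmas to the samples' device and dtype, as in the diff.
        sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        schedule_timesteps = schedule_timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

        # Map each timestep value to its position in the schedule.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        # Forward noising with sigmas: x_t = x_0 + sigma_t * eps
        return original_samples + noise * sigma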

src/diffusers/schedulers/scheduling_pndm.py

@@ -400,11 +400,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        if self.alphas_cumprod.device != original_samples.device:
-            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
-        if timesteps.device != original_samples.device:
-            timesteps = timesteps.to(original_samples.device)
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
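
A quick way to see the combined effect of these four changes; the scheduler choice, shapes, and timestep value below are illustrative rather than taken from the diff:

    import torch
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sample = torch.randn(1, 4, 64, 64, device=device).half()
    noise = torch.randn(1, 4, 64, 64, device=device).half()
    timesteps = torch.tensor([500], device=device, dtype=torch.long)

    # With alphas_cumprod cast to the sample's dtype, the noisy latents keep float16.
    noisy = scheduler.add_noise(sample, noise, timesteps)
    print(noisy.dtype)  # torch.float16 (previously came back as float32)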