mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[schedulers] hanlde dtype in add_noise (#767)
* handle dtype in vae and image2image pipeline * handle dtype in add noise * don't modify vae and pipeline * remove the if
This commit is contained in:
committed by
Patrick von Platen
parent
91ddd2a25b
commit
2bdde4dd83
@@ -300,11 +300,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
|
||||
@@ -294,11 +294,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
|
||||
@@ -257,9 +257,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
sigmas = self.sigmas.to(original_samples.device)
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
self.timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
|
||||
deprecate(
|
||||
"timesteps as indices",
|
||||
@@ -273,7 +277,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
sigma = self.sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
|
||||
@@ -400,11 +400,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
|
||||
Reference in New Issue
Block a user