Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add Recent Timestep Scheduling Improvements to DDIM Inverse Scheduler (#3865)
* Add Recent Timestep Scheduling Improvements to DDIM Inverse Scheduler
  - Roll timesteps by one to reflect origin-destination semantic discrepancy
  - Restore `set_alpha_to_one` option to handle negative initial timesteps
  - Remove `set_alpha_to_zero` option not used due to previous truncation
* Bugfix
* Remove unnecessary calls to `detach()`
  - Use `self.image_processor.preprocess` in DiffEdit pipeline functions
* Preprocess list input for inverted image latents in diffedit pipeline
* Add `timestep_spacing` and `steps_offset` to `DPMSolverMultistepInverseScheduler`
* Update expected test results to account for inverting last forward diffusion step
* Fix inversion progress bar bug
* Add first draft for proper fast tests for DDIMInverseScheduler
* Add deprecated DDIMInverseScheduler kwarg to ConfigMixin registry
* Fix test failure in DPMMultistepInverseScheduler
  - Invert step specification leads to negative noise variance in SDE-based algs
  - Add first draft for proper fast tests for DPMMultistepInverseScheduler
* Update expected test results to account for inverting last forward diffusion step
  - Clean up diffedit fast test
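For intuition on the timestep roll described above, here is a minimal standalone sketch (plain NumPy, not the scheduler itself) of how the inverted grid is built with the default "leading" spacing and no steps offset:

import numpy as np

# Minimal sketch, assuming num_train_timesteps=1000 and num_inference_steps=4.
num_train_timesteps, num_inference_steps = 1000, 4
step_ratio = num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
print(timesteps)  # [  0 250 500 750]

# Each inversion step now names its *origin* timestep; the first origin sits
# one step_ratio before timestep 0, which is why negative timesteps appear.
timesteps = np.roll(timesteps, 1)
timesteps[0] = int(timesteps[1] - step_ratio)
print(timesteps)  # [-250    0  250  500]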
@@ -992,7 +992,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         )

         # 4. Preprocess image
-        image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)
+        image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)

         # 5. Set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1176,7 +1176,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)

         # 4. Prepare latent variables
         num_images_per_prompt = 1
@@ -1201,9 +1201,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM

         # 7. Noising loop where we obtain the intermediate noised latent image for each timestep.
         num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
-        inverted_latents = [latents.detach().clone()]
-        with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
-            for i, t in enumerate(timesteps[:-1]):
+        inverted_latents = []
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
@@ -1270,7 +1270,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         # 8. Post-processing
         image = None
         if decode_latents:
-            image = self.decode_latents(latents.flatten(0, 1).detach())
+            image = self.decode_latents(latents.flatten(0, 1))

         # 9. Convert to PIL.
         if decode_latents and output_type == "pil":
@@ -1291,7 +1291,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         self,
         prompt: Optional[Union[str, List[str]]] = None,
         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        image_latents: torch.FloatTensor = None,
+        image_latents: Union[torch.FloatTensor, PIL.Image.Image] = None,
         inpaint_strength: Optional[float] = 0.8,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
@@ -1447,7 +1447,13 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device)

         # 6. Preprocess image latents
-        image_latents = preprocess(image_latents)
+        if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents):
+            image_latents = torch.cat(image_latents).detach()
+        elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5:
+            image_latents = image_latents.detach()
+        else:
+            image_latents = self.image_processor.preprocess(image_latents).detach()

         latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
         if image_latents.shape[-3:] != latent_shape:
             raise ValueError(
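To make the new branching concrete, a hedged toy example (dummy shapes, not pipeline code) of the list-of-5-D-latents case that the inversion step can now hand back:

import torch

# Two hypothetical chunks of inverted latents, each shaped
# (timesteps, batch, channels, height, width), i.e. ndim == 5.
chunks = [torch.randn(2, 1, 4, 8, 8), torch.randn(2, 1, 4, 8, 8)]
if isinstance(chunks, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in chunks):
    image_latents = torch.cat(chunks)  # concatenated instead of re-preprocessed
print(image_latents.shape)  # torch.Size([4, 1, 4, 8, 8])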
@@ -1458,8 +1464,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
         if image_latents.shape[:2] != (batch_size, len(timesteps)):
             raise ValueError(
-                f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, "
-                f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps."
+                f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}"
+                f" timesteps, but has batch size {image_latents.shape[0]} with latent images from"
+                f" {image_latents.shape[1]} timesteps."
             )
         image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
         image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype)
@@ -1468,7 +1475,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        latents = image_latents[0].detach().clone()
+        latents = image_latents[0].clone()
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1183,8 +1183,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):

         # 7. Denoising loop where we obtain the cross-attention maps.
         num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
-        with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
-            for i, t in enumerate(timesteps[:-1]):
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
@@ -90,6 +90,43 @@ def betas_for_alpha_bar(
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
     """
     DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].
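A quick numeric sanity check of the rescaling (a hedged sketch re-inlining the same algebra on the scheduler's default linear betas):

import torch

betas = torch.linspace(0.0001, 0.02, 1000)  # default linear schedule
alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()
a_0, a_T = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
alphas_bar_sqrt = (alphas_bar_sqrt - a_T) * a_0 / (a_0 - a_T)  # shift + scale

alphas_bar = alphas_bar_sqrt**2
alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
betas_rescaled = 1 - alphas

# Terminal cumulative alpha is (numerically) zero, so terminal SNR is zero.
print(torch.cumprod(1 - betas_rescaled, dim=0)[-1])  # tensor(0.), up to float error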
@@ -126,9 +163,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
+        timestep_spacing (`str`, default `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, default `False`):
+            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
+            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
+            medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     order = 1
+    ignore_for_config = ["kwargs"]
+    _deprecated_kwargs = ["set_alpha_to_zero"]

     @register_to_config
     def __init__(
@@ -139,18 +186,20 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
-        set_alpha_to_zero: bool = True,
+        set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
         clip_sample_range: float = 1.0,
+        timestep_spacing: str = "leading",
+        rescale_betas_zero_snr: bool = False,
         **kwargs,
     ):
-        if kwargs.get("set_alpha_to_one", None) is not None:
+        if kwargs.get("set_alpha_to_zero", None) is not None:
             deprecation_message = (
-                "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
+                "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
             )
-            deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
-            set_alpha_to_zero = kwargs["set_alpha_to_one"]
+            deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
+            set_alpha_to_one = kwargs["set_alpha_to_zero"]
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
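Under these changes the old kwarg still constructs the scheduler; a hedged sketch of both spellings:

from diffusers import DDIMInverseScheduler

# Deprecated spelling: accepted via **kwargs, emits a deprecation warning, and
# is forwarded to the new `set_alpha_to_one` option inside __init__.
scheduler = DDIMInverseScheduler(set_alpha_to_zero=False)

# Current spelling.
scheduler = DDIMInverseScheduler(set_alpha_to_one=False)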
@@ -166,15 +215,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

         # At every step in inverted ddim, we are looking into the next alphas_cumprod
-        # For the final step, there is no next alphas_cumprod, and the index is out of bounds
-        # `set_alpha_to_zero` decides whether we set this parameter simply to zero
-        # in this case, self.step() just output the predicted noise
-        # or whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]
+        # For the initial step, there is no current alphas_cumprod, and the index is out of bounds
+        # `set_alpha_to_one` decides whether we set this parameter simply to one
+        # or whether we use the initial alpha used in training the diffusion model.
+        self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
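Why the boundary value matters: after the timestep roll, the first inversion step carries a negative timestep, so there is no trained alphas_cumprod entry for it. A hedged sketch of the fallback that step() performs:

import torch

alphas_cumprod = torch.cumprod(1.0 - torch.linspace(0.0001, 0.02, 1000), dim=0)
initial_alpha_cumprod = torch.tensor(1.0)  # set_alpha_to_one=True

def alpha_prod(timestep: int) -> torch.Tensor:
    # Mirrors the `timestep >= 0` guard added to step() further below.
    return alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod

print(alpha_prod(-250))  # tensor(1.): the sample is treated as clean data
print(alpha_prod(0))     # first trained cumulative alpha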
@@ -215,12 +268,29 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             )

         self.num_inference_steps = num_inference_steps
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
+
+        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "leading":
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+            )
+
+        # Roll timesteps array by one to reflect reversed origin and destination semantics for each step
+        timesteps = np.roll(timesteps, 1)
+        timesteps[0] = int(timesteps[1] - step_ratio)
         self.timesteps = torch.from_numpy(timesteps).to(device)
-        self.timesteps += self.config.steps_offset

     def step(
         self,
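The two spacings produce different grids before the roll; a hedged side-by-side (1000 training steps, 4 inference steps, no offset):

import numpy as np

num_train_timesteps, n = 1000, 4

# "leading": integer multiples of the step ratio, counted from 0.
step_ratio = num_train_timesteps // n
leading = (np.arange(0, n) * step_ratio).round().astype(np.int64)
print(leading)  # [  0 250 500 750]

# "trailing": counted back from num_train_timesteps, then shifted down by one.
step_ratio = num_train_timesteps / n
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) - 1
print(trailing)  # [249 499 749 999]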
@@ -237,12 +307,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):

         # 2. compute alphas, betas
         # change original implementation to exactly match noise levels for analogous forward process
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = (
-            self.alphas_cumprod[prev_timestep]
-            if prev_timestep < self.config.num_train_timesteps
-            else self.final_alpha_cumprod
-        )
+        alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]

         beta_prod_t = 1 - alpha_prod_t

@@ -138,6 +138,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
             Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
             diffusion ODEs.
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -162,6 +169,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
+        timestep_spacing: str = "linspace",
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -221,19 +230,41 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
         """
         # Clipping the minimum of all lambda(t) for numerical stability.
         # This is critical for cosine (squaredcos_cap_v2) noise schedule.
-        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item()
         self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx
-        timesteps = (
-            np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
-        )

-        if self.use_karras_sigmas:
-            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "linspace":
+            timesteps = (
+                np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
+            )
+        elif self.config.timestep_spacing == "leading":
+            step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1)
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', "
+                "'leading' or 'trailing'."
+            )
+
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)

         self.sigmas = torch.from_numpy(sigmas)

         # when num_inference_steps == num_train_timesteps, we can end up with
         # duplicates in timesteps.
         _, unique_indices = np.unique(timesteps, return_index=True)
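For reference, a hedged peek at what the default "linspace" branch yields (999 as the noisiest timestep, 10 inference steps):

import numpy as np

noisiest_timestep, num_inference_steps = 999, 10
timesteps = (
    np.linspace(0, noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
)
# Ascending grid that stops short of the noisiest trained timestep; the extra
# linspace endpoint is dropped by [:-1].
print(timesteps)  # [  0 100 200 300 400 500 599 699 799 899]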
@@ -397,7 +428,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):

         return epsilon

-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
     def dpm_solver_first_order_update(
         self,
         model_output: torch.FloatTensor,
@@ -429,23 +459,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
         elif self.config.algorithm_type == "dpmsolver":
             x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
-        elif self.config.algorithm_type == "sde-dpmsolver++":
-            assert noise is not None
-            x_t = (
-                (sigma_t / sigma_s * torch.exp(-h)) * sample
-                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
-                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
-            )
-        elif self.config.algorithm_type == "sde-dpmsolver":
-            assert noise is not None
-            x_t = (
-                (alpha_t / alpha_s) * sample
-                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
-                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-            )
+        elif "sde" in self.config.algorithm_type:
+            raise NotImplementedError(
+                f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
+            )
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
     def multistep_dpm_solver_second_order_update(
         self,
         model_output_list: List[torch.FloatTensor],
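Why the SDE branches are removed rather than inverted, per the commit message ("negative noise variance in SDE-based algs"): under the inverted step ordering the log-SNR step h can flip sign, and the sde-dpmsolver++ noise term then takes the square root of a negative number. A hedged one-liner:

import torch

h = torch.tensor(-0.5)  # hypothetical sign-flipped step size under inversion
print(torch.sqrt(1.0 - torch.exp(-2.0 * h)))  # tensor(nan): variance < 0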
@@ -504,38 +523,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * (torch.exp(h) - 1.0)) * D0
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
             )
-        elif self.config.algorithm_type == "sde-dpmsolver++":
-            assert noise is not None
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
-                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
-                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
-                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
-                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
-                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
-                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
-                )
-        elif self.config.algorithm_type == "sde-dpmsolver":
-            assert noise is not None
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
-                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-                )
+        elif "sde" in self.config.algorithm_type:
+            raise NotImplementedError(
+                f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
+            )
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
@@ -220,7 +220,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
         image = sd_pipe.invert(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4833, 0.4696, 0.5574, 0.5194, 0.5248, 0.5638, 0.5040, 0.5423, 0.5072])
+        expected_slice = np.array([0.4823, 0.4783, 0.5638, 0.5201, 0.5247, 0.5644, 0.5029, 0.5404, 0.5062])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@@ -235,7 +235,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
         image = sd_pipe.invert(**inputs).images
         image_slice = image[1, -3:, -3:, -1]
         assert image.shape == (2, 32, 32, 3)
-        expected_slice = np.array([0.6672, 0.5203, 0.4908, 0.4376, 0.4517, 0.5544, 0.4605, 0.4826, 0.5007])
+        expected_slice = np.array([0.6446, 0.5232, 0.4914, 0.4441, 0.4654, 0.5546, 0.4650, 0.4938, 0.5044])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@@ -250,7 +250,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli

         self.assertEqual(image.shape, (2, 32, 32, 3))
         expected_slice = np.array(
-            [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799],
+            [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.5105, 0.5015, 0.4407, 0.4799],
         )
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
@@ -277,7 +277,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli

         self.assertEqual(image.shape, (2, 32, 32, 3))
         expected_slice = np.array(
-            [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799],
+            [0.5305, 0.4673, 0.5314, 0.5308, 0.4886, 0.5279, 0.5142, 0.4724, 0.4892],
         )
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
tests/schedulers/test_scheduler_ddim_inverse.py (new file, 135 lines)
@@ -0,0 +1,135 @@
+import torch
+
+from diffusers import DDIMInverseScheduler
+
+from .test_schedulers import SchedulerCommonTest
+
+
+class DDIMInverseSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DDIMInverseScheduler,)
+    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "clip_sample": True,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps, eta = 10, 0.0
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+
+        scheduler.set_timesteps(num_inference_steps)
+
+        for t in scheduler.timesteps:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample, eta).prev_sample
+
+        return sample
+
+    def test_timesteps(self):
+        for timesteps in [100, 500, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_steps_offset(self):
+        for steps_offset in [0, 1]:
+            self.check_over_configs(steps_offset=steps_offset)
+
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(steps_offset=1)
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(5)
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([-199, 1, 201, 401, 601]))
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "squaredcos_cap_v2"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_clip_sample(self):
+        for clip_sample in [True, False]:
+            self.check_over_configs(clip_sample=clip_sample)
+
+    def test_timestep_spacing(self):
+        for timestep_spacing in ["trailing", "leading"]:
+            self.check_over_configs(timestep_spacing=timestep_spacing)
+
+    def test_rescale_betas_zero_snr(self):
+        for rescale_betas_zero_snr in [True, False]:
+            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for threshold in [0.5, 1.0, 2.0]:
+            for prediction_type in ["epsilon", "v_prediction"]:
+                self.check_over_configs(
+                    thresholding=True,
+                    prediction_type=prediction_type,
+                    sample_max_value=threshold,
+                )
+
+    def test_time_indices(self):
+        for t in [1, 10, 49]:
+            self.check_over_forward(time_step=t)
+
+    def test_inference_steps(self):
+        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
+            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
+
+    def test_add_noise_device(self):
+        pass
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 509.1079) < 1e-2
+        assert abs(result_mean.item() - 0.6629) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 1029.129) < 1e-2
+        assert abs(result_mean.item() - 1.3400) < 1e-3
+
+    def test_full_loop_with_set_alpha_to_one(self):
+        # We specify different beta, so that the first alpha is 0.99
+        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 259.8116) < 1e-2
+        assert abs(result_mean.item() - 0.3383) < 1e-3
+
+    def test_full_loop_with_no_set_alpha_to_one(self):
+        # We specify different beta, so that the first alpha is 0.99
+        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 239.055) < 1e-2
+        assert abs(result_mean.item() - 0.3113) < 1e-3
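The `test_steps_offset` expectation above also shows up in everyday use; a hedged sketch (model id and step count are illustrative) of pairing the inverse scheduler with a Stable Diffusion scheduler config:

from diffusers import DDIMInverseScheduler, DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)

inverse_scheduler.set_timesteps(50)
# With "leading" spacing and steps_offset=1, the rolled grid starts below zero:
print(inverse_scheduler.timesteps[:3])  # tensor([-19,   1,  21]) (expected)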
tests/schedulers/test_scheduler_dpm_multi_inverse.py (new file, 266 lines)
@@ -0,0 +1,266 @@
+import tempfile
+
+import torch
+
+from diffusers import DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler
+
+from .test_schedulers import SchedulerCommonTest
+
+
+class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DPMSolverMultistepInverseScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+            "prediction_type": "epsilon",
+            "thresholding": False,
+            "sample_max_value": 1.0,
+            "algorithm_type": "dpmsolver++",
+            "solver_type": "midpoint",
+            "lower_order_final": False,
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                # copy over dummy past residuals
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residual (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["midpoint", "heun"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            thresholding=True,
+                            prediction_type=prediction_type,
+                            sample_max_value=threshold,
+                            algorithm_type="dpmsolver++",
+                            solver_order=order,
+                            solver_type=solver_type,
+                        )
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_solver_order_and_type(self):
+        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
+            for solver_type in ["midpoint", "heun"]:
+                for order in [1, 2, 3]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        sample = self.full_loop(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        assert not torch.isnan(sample).any(), "Samples have nan numbers"
+
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
+
+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
+    def test_timestep_spacing(self):
+        for timestep_spacing in ["trailing", "leading"]:
+            self.check_over_configs(timestep_spacing=timestep_spacing)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.7047) < 1e-3
+
+    def test_full_loop_no_noise_thres(self):
+        sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 19.8933) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 1.5194) < 1e-3
+
+    def test_full_loop_with_karras_and_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 1.7833) < 1e-3
+
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepInverseScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.7047) < 1e-3
+
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepInverseScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        new_result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(new_result_mean.item() - result_mean.item()) < 1e-3
+
+    def test_fp16_support(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.half()
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        assert sample.dtype == torch.float16
+
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
+            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
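As `test_switch` above exercises, the inverse scheduler is config-compatible with `DPMSolverMultistepScheduler`; a hedged usage sketch of the round-trip:

from diffusers import DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler

forward = DPMSolverMultistepScheduler()  # library defaults
inverse = DPMSolverMultistepInverseScheduler.from_config(forward.config)

inverse.set_timesteps(10)
# The inverse grid ascends from clean data toward the noisiest trained
# timestep, the mirror image of the forward scheduler's descending grid.
print(inverse.timesteps)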