diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index f9bb8d0af4..b2d5953808 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -992,7 +992,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM ) # 4. Preprocess image - image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1176,7 +1176,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 4. Prepare latent variables num_images_per_prompt = 1 @@ -1201,9 +1201,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - inverted_latents = [latents.detach().clone()] - with self.progress_bar(total=num_inference_steps - 1) as progress_bar: - for i, t in enumerate(timesteps[:-1]): + inverted_latents = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) @@ -1270,7 +1270,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM # 8. Post-processing image = None if decode_latents: - image = self.decode_latents(latents.flatten(0, 1).detach()) + image = self.decode_latents(latents.flatten(0, 1)) # 9. Convert to PIL. if decode_latents and output_type == "pil": @@ -1291,7 +1291,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM self, prompt: Optional[Union[str, List[str]]] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - image_latents: torch.FloatTensor = None, + image_latents: Union[torch.FloatTensor, PIL.Image.Image] = None, inpaint_strength: Optional[float] = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, @@ -1447,7 +1447,13 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) # 6. 
Preprocess image latents - image_latents = preprocess(image_latents) + if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): + image_latents = torch.cat(image_latents).detach() + elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: + image_latents = image_latents.detach() + else: + image_latents = self.image_processor.preprocess(image_latents).detach() + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) if image_latents.shape[-3:] != latent_shape: raise ValueError( @@ -1458,8 +1464,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) if image_latents.shape[:2] != (batch_size, len(timesteps)): raise ValueError( - f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, " - f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps." + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" + f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" + f" {image_latents.shape[1]} timesteps." ) image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) @@ -1468,7 +1475,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - latents = image_latents[0].detach().clone() + latents = image_latents[0].clone() num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 902bcdd049..960c4369e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -1183,8 +1183,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 7. Denoising loop where we obtain the cross-attention maps. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - with self.progress_bar(total=num_inference_steps - 1) as progress_bar: - for i, t in enumerate(timesteps[:-1]): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index c04aabe035..9db83eb992 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -90,6 +90,43 @@ def betas_for_alpha_bar( return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. @@ -126,9 +163,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" order = 1 + ignore_for_config = ["kwargs"] + _deprecated_kwargs = ["set_alpha_to_zero"] @register_to_config def __init__( @@ -139,18 +186,20 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, - set_alpha_to_zero: bool = True, + set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, **kwargs, ): - if kwargs.get("set_alpha_to_one", None) is not None: + if kwargs.get("set_alpha_to_zero", None) is not None: deprecation_message = ( - "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead." + "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead." ) - deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False) - set_alpha_to_zero = kwargs["set_alpha_to_one"] + deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False) + set_alpha_to_one = kwargs["set_alpha_to_zero"] if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -166,15 +215,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in inverted ddim, we are looking into the next alphas_cumprod - # For the final step, there is no next alphas_cumprod, and the index is out of bounds - # `set_alpha_to_zero` decides whether we set this parameter simply to zero + # For the initial step, there is no current alphas_cumprod, and the index is out of bounds + # `set_alpha_to_one` decides whether we set this parameter simply to one # in this case, self.step() just output the predicted noise - # or whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] + # or whether we use the initial alpha used in training the diffusion model. + self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -215,12 +268,29 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ) self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) + + # "leading" and "trailing" corresponds to annotation of Table 1. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + # Roll timesteps array by one to reflect reversed origin and destination semantics for each step + timesteps = np.roll(timesteps, 1) + timesteps[0] = int(timesteps[1] - step_ratio) self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset def step( self, @@ -237,12 +307,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): # 2. compute alphas, betas # change original implementation to exactly match noise levels for analogous forward process - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = ( - self.alphas_cumprod[prev_timestep] - if prev_timestep < self.config.num_train_timesteps - else self.final_alpha_cumprod - ) + alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] beta_prod_t = 1 - alpha_prod_t diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index a6736b3544..f5c98a551d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -138,6 +138,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -162,6 +169,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -221,19 +230,41 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): """ # Clipping the minimum of all lambda(t) for numerical stability. 
# This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item() self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx - timesteps = ( - np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) - ) - if self.use_karras_sigmas: - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', " + "'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = timesteps.copy().astype(np.int64) + self.sigmas = torch.from_numpy(sigmas) + # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
_, unique_indices = np.unique(timesteps, return_index=True) @@ -397,7 +428,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): return epsilon - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, @@ -429,23 +459,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - x_t = ( - (sigma_t / sigma_s * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + elif "sde" in self.config.algorithm_type: + raise NotImplementedError( + f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." ) return x_t - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], @@ -504,38 +523,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) + elif "sde" in self.config.algorithm_type: + raise NotImplementedError( + f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." 
+ ) return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 1b17f8b31b..7d94b9f230 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -220,7 +220,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip image = sd_pipe.invert(**inputs).images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4833, 0.4696, 0.5574, 0.5194, 0.5248, 0.5638, 0.5040, 0.5423, 0.5072]) + expected_slice = np.array([0.4823, 0.4783, 0.5638, 0.5201, 0.5247, 0.5644, 0.5029, 0.5404, 0.5062]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -235,7 +235,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip image = sd_pipe.invert(**inputs).images image_slice = image[1, -3:, -3:, -1] assert image.shape == (2, 32, 32, 3) - expected_slice = np.array([0.6672, 0.5203, 0.4908, 0.4376, 0.4517, 0.5544, 0.4605, 0.4826, 0.5007]) + expected_slice = np.array([0.6446, 0.5232, 0.4914, 0.4441, 0.4654, 0.5546, 0.4650, 0.4938, 0.5044]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 1de80d60d8..88aeb50dc1 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -250,7 +250,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli self.assertEqual(image.shape, (2, 32, 32, 3)) expected_slice = np.array( - [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], + [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.5105, 0.5015, 0.4407, 0.4799], ) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -277,7 +277,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli self.assertEqual(image.shape, (2, 32, 32, 3)) expected_slice = np.array( - [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], + [0.5305, 0.4673, 0.5314, 0.5308, 0.4886, 0.5279, 0.5142, 0.4724, 0.4892], ) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) diff --git a/tests/schedulers/test_scheduler_ddim_inverse.py b/tests/schedulers/test_scheduler_ddim_inverse.py new file mode 100644 index 0000000000..39ee26306c --- /dev/null +++ b/tests/schedulers/test_scheduler_ddim_inverse.py @@ -0,0 +1,135 @@ +import torch + +from diffusers import DDIMInverseScheduler + +from .test_schedulers import SchedulerCommonTest + + +class DDIMInverseSchedulerTest(SchedulerCommonTest): + scheduler_classes = (DDIMInverseScheduler,) + forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50)) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = 
self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps, eta = 10, 0.0 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + scheduler.set_timesteps(num_inference_steps) + + for t in scheduler.timesteps: + residual = model(sample, t) + sample = scheduler.step(residual, t, sample, eta).prev_sample + + return sample + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(5) + assert torch.equal(scheduler.timesteps, torch.LongTensor([-199, 1, 201, 401, 601])) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_timestep_spacing(self): + for timestep_spacing in ["trailing", "leading"]: + self.check_over_configs(timestep_spacing=timestep_spacing) + + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + ) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_add_noise_device(self): + pass + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 509.1079) < 1e-2 + assert abs(result_mean.item() - 0.6629) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 1029.129) < 1e-2 + assert abs(result_mean.item() - 1.3400) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 259.8116) < 1e-2 + assert abs(result_mean.item() - 0.3383) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = 
torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 239.055) < 1e-2 + assert abs(result_mean.item() - 0.3113) < 1e-3 diff --git a/tests/schedulers/test_scheduler_dpm_multi_inverse.py b/tests/schedulers/test_scheduler_dpm_multi_inverse.py new file mode 100644 index 0000000000..61a1d82e0f --- /dev/null +++ b/tests/schedulers/test_scheduler_dpm_multi_inverse.py @@ -0,0 +1,266 @@ +import tempfile + +import torch + +from diffusers import DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler + +from .test_schedulers import SchedulerCommonTest + + +class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): + scheduler_classes = (DPMSolverMultistepInverseScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "prediction_type": "epsilon", + "thresholding": False, + "sample_max_value": 1.0, + "algorithm_type": "dpmsolver++", + "solver_type": "midpoint", + "lower_order_final": False, + "lambda_min_clipped": -float("inf"), + "variance_type": None, + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output, new_output = sample, sample + for t in range(time_step, time_step + scheduler.config.solver_order + 1): + output = scheduler.step(residual, t, output, **kwargs).prev_sample + new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past 
residual (must be after setting timesteps) + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, scheduler=None, **config): + if scheduler is None: + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["midpoint", "heun"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + algorithm_type="dpmsolver++", + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_solver_order_and_type(self): + for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for solver_type in ["midpoint", "heun"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) + assert not torch.isnan(sample).any(), "Samples have nan numbers" + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def 
test_lambda_min_clipped(self): + self.check_over_configs(lambda_min_clipped=-float("inf")) + self.check_over_configs(lambda_min_clipped=-5.1) + + def test_variance_type(self): + self.check_over_configs(variance_type=None) + self.check_over_configs(variance_type="learned_range") + + def test_timestep_spacing(self): + for timestep_spacing in ["trailing", "leading"]: + self.check_over_configs(timestep_spacing=timestep_spacing) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.7047) < 1e-3 + + def test_full_loop_no_noise_thres(self): + sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 19.8933) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 1.5194) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 1.7833) < 1e-3 + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = DPMSolverMultistepInverseScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.7047) < 1e-3 + + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = DPMSolverMultistepInverseScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + new_result_mean = torch.mean(torch.abs(sample)) + + assert abs(new_result_mean.item() - result_mean.item()) < 1e-3 + + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.half() + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
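
Note on the reworked inverse-DDIM timestep schedule: the following is a minimal standalone sketch (NumPy only, not part of the patch) of the rolled "leading" spacing that the updated DDIMInverseScheduler.set_timesteps builds. The constants and the expected output mirror the assertion in test_steps_offset above; the variable names are illustrative, not the scheduler's internals.

    # Sketch of the "leading" spacing plus the roll-by-one applied in
    # DDIMInverseScheduler.set_timesteps (assumptions: 1000 training
    # timesteps, 5 inference steps, steps_offset=1, as in test_steps_offset).
    import numpy as np

    num_train_timesteps = 1000
    num_inference_steps = 5
    steps_offset = 1

    # "leading" spacing per Table 2 of https://arxiv.org/abs/2305.08891
    step_ratio = num_train_timesteps // num_inference_steps
    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
    timesteps += steps_offset  # [1, 201, 401, 601, 801]

    # Roll by one so each entry names the origin of a step in the reversed
    # (noising) direction; the first entry becomes a synthetic timestep
    # below zero, one step_ratio before the old first timestep.
    timesteps = np.roll(timesteps, 1)
    timesteps[0] = int(timesteps[1] - step_ratio)

    assert timesteps.tolist() == [-199, 1, 201, 401, 601]

The synthetic negative first entry (-199 here) is why the patched step() falls back to initial_alpha_cumprod whenever timestep < 0 instead of indexing alphas_cumprod out of bounds.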