From c8656ed73c638e51fc2e777a5fd355d69fa5220f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 26 Nov 2025 15:34:22 +0530 Subject: [PATCH 1/3] [docs] put autopipeline after overview and hunyuanimage in images (#12548) put autopipeline after overview and hunyuanimage in images From a88a7b4f03f6b174a991412bc106f468e4a937a1 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Wed, 26 Nov 2025 17:38:41 +0100 Subject: [PATCH 2/3] Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710) * Improve docstrings and type hints in multiple diffusion schedulers * docs: update Imagen Video paper link to Hugging Face Papers. --- .../scheduling_cosine_dpmsolver_multistep.py | 21 ++- .../schedulers/scheduling_deis_multistep.py | 47 +++++- .../scheduling_dpmsolver_multistep.py | 155 ++++++++++++------ .../scheduling_dpmsolver_multistep_inverse.py | 11 ++ .../scheduling_dpmsolver_singlestep.py | 47 +++++- .../scheduling_edm_dpmsolver_multistep.py | 21 ++- .../schedulers/scheduling_sasolver.py | 32 +++- .../schedulers/scheduling_unipc_multistep.py | 47 +++++- 8 files changed, 329 insertions(+), 52 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index 7b11d70493..8d50ee6c7e 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index bf8e1d98d6..45d11c9426 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. 
+ """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): raise NotImplementedError("only support log-rho multistep deis now") # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: @@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index dee97f39ff..e7ba0ba1f3 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. 
- prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`. + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`): + Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample` + directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`): + Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the + [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the + algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`): + Solver type for the second-order solver. The solver type slightly affects the sample quality, especially + for a small number of steps. It is recommended to use `midpoint` solvers. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. @@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. flow_shift (`float`, *optional*, defaults to 1.0): The shift value for the timestep schedule for flow matching. - final_sigmas_type (`str`, defaults to `"zero"`): + final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. 
If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the timestep schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shift to apply when using dynamic shifting. Currently, `"exponential"` is the only
+            supported shift type.
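+
+        Example:
+
+        ```python
+        >>> # Usage sketch (illustrative, not part of the original patch): construct the
+        >>> # scheduler with the documented literal options and set a timestep schedule.
+        >>> from diffusers import DPMSolverMultistepScheduler
+
+        >>> scheduler = DPMSolverMultistepScheduler(
+        ...     beta_schedule="scaled_linear",
+        ...     algorithm_type="dpmsolver++",
+        ...     solver_type="midpoint",
+        ...     solver_order=2,
+        ...     final_sigmas_type="zero",
+        ... )
+        >>> scheduler.set_timesteps(num_inference_steps=25)
+        >>> len(scheduler.timesteps)
+        25
+        ```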
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", + algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, @@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): use_lu_lambdas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - timestep_spacing: str = "linspace", + variance_type: Optional[Literal["learned", "learned_range"]] = None, + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", + time_shift_type: Literal["exponential"] = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None, timesteps: Optional[List[int]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + mu (`float`, *optional*): + The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and + `time_shift_type="exponential"`. timesteps (`List[int]`, *optional*): Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` @@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. 
@@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): t = t.reshape(sigma.shape) return t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Lu et al. (2022).""" + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """ + Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model + Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022). + + Args: + in_lambdas (`torch.Tensor`): + The input lambda values to be converted. + num_inference_steps (`int`): + The number of inference steps to generate the noise schedule for. + + Returns: + `torch.Tensor`: + The converted lambda values following the Lu noise schedule. + """ lambda_min: float = in_lambdas[-1].item() lambda_max: float = in_lambdas[0].item() @@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ) return x_t - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): return step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: @@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: @@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): + The direct output from the learned diffusion model. + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. 
- variance_noise (`torch.Tensor`): + variance_noise (`torch.Tensor`, *optional*): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`LEdits++`]. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ @@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 0f734aeb54..2c5d798be0 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 0b271d7eac..c51171cc98 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. 
+ """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): raise ValueError(f"Order must be 1, 2, 3, got {order}") # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: @@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index eeec588e27..5b1e84dc3a 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. 
""" if self.begin_index is None: diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index d9054c39c9..9eb37c44ae 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7dc5f46768..606dfeb239 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. 
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps

@@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
         """

         if self.begin_index is None:
@@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

From e6d46123091afd58281dc7487c0f6b67055683b9 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Thu, 27 Nov 2025 01:18:57 +0800
Subject: [PATCH 3/3] Support unittest for Z-image ⚡️ (#12715)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add Support for Z-Image.
* Reformatting with make style, black & isort.
* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.
* modified main model forward, freqs_cis left
* refactored to add B dim
* fixed stack issue
* fixed modulation bug
* fixed modulation bug
* fix bug
* remove value_from_time_aware_config
* styling
* Fix neg embed and divide (/) bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.
* Replace padding with pad_sequence; Add gradient checkpointing.
* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.
* Fix Docstring and Make Style.
* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed11d55146103c97740bad4a5f91744e0.

* update z-image docstring
* Revert attention dispatcher
* update z-image docstring
* styling
* Recover attention_dispatch.py with its original impl; a special commit for fa3 compatibility will follow.
* Fix prev bug, and support passing prompt_embeds in args after prompt pre-encoding, as a List of torch Tensors.
* Remove einops dependency.
* remove redundant imports & make fix-copies
* fix import
* Support for num_images_per_prompt>1; Remove redundant, unused variables.
* Fix bugs for num_images_per_prompt with actual batch.
* Add unit tests for Z-Image.
* Refine unittests and skip cases that need a separate test env; Fix compatibility with unittests in model, mostly precision formatting.
* Add clean env for test_save_load_float16 separate test; Add Note; Styling.
* Update dtype mentioned by yiyi.
--------- Co-authored-by: liudongyang --- .../transformers/transformer_z_image.py | 19 +- .../pipelines/z_image/pipeline_z_image.py | 39 +-- tests/pipelines/z_image/test_z_image.py | 306 ++++++++++++++++++ 3 files changed, 336 insertions(+), 28 deletions(-) create mode 100644 tests/pipelines/z_image/test_z_image.py diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index a5c1de682a..3ad835ceee 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -69,7 +69,10 @@ class TimestepEmbedder(nn.Module): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + weight_dtype = self.mlp[0].weight.dtype + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + t_emb = self.mlp(t_freq) return t_emb @@ -126,6 +129,10 @@ class ZSingleStreamAttnProcessor: dtype = query.dtype query, key = query.to(dtype), key.to(dtype) + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + # Compute joint attention hidden_states = dispatch_attention_fn( query, @@ -306,6 +313,10 @@ class RopeEmbedder: if self.freqs_cis is None: self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] result = [] for i in range(len(self.axes_dims)): @@ -317,6 +328,7 @@ class RopeEmbedder: class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True _no_split_modules = ["ZImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers @register_to_config def __init__( @@ -553,8 +565,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr t = t * self.t_scale t = self.t_embedder(t) - adaln_input = t - ( x, cap_feats, @@ -572,6 +582,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr x = torch.cat(x, dim=0) x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) x[torch.cat(x_inner_pad_mask)] = self.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py index cc4e9d5201..a4fcacb6eb 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py @@ -165,21 +165,16 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[List[torch.FloatTensor]] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = 
None, max_sequence_length: int = 512, - lora_scale: Optional[float] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt prompt_embeds = self._encode_prompt( prompt=prompt, device=device, - dtype=dtype, - num_images_per_prompt=num_images_per_prompt, prompt_embeds=prompt_embeds, max_sequence_length=max_sequence_length, ) @@ -193,8 +188,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): negative_prompt_embeds = self._encode_prompt( prompt=negative_prompt, device=device, - dtype=dtype, - num_images_per_prompt=num_images_per_prompt, prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, ) @@ -206,12 +199,9 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - num_images_per_prompt: int = 1, prompt_embeds: Optional[List[torch.FloatTensor]] = None, max_sequence_length: int = 512, ) -> List[torch.FloatTensor]: - assert num_images_per_prompt == 1 device = device or self._execution_device if prompt_embeds is not None: @@ -417,8 +407,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): f"Please adjust the width to a multiple of {vae_scale}." ) - assert self.dtype == torch.bfloat16 - dtype = self.dtype device = self._execution_device self._guidance_scale = guidance_scale @@ -434,10 +422,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): else: batch_size = len(prompt_embeds) - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - # If prompt_embeds is provided and prompt is None, skip encoding if prompt_embeds is not None and prompt is None: if self.do_classifier_free_guidance and negative_prompt_embeds is None: @@ -455,11 +439,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): do_classifier_free_guidance=self.do_classifier_free_guidance, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - dtype=dtype, device=device, - num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, - lora_scale=lora_scale, ) # 4. Prepare latent variables @@ -475,6 +456,14 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): generator, latents, ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) # 5. 
Prepare timesteps @@ -523,12 +512,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 if apply_cfg: - latents_typed = latents if latents.dtype == dtype else latents.to(dtype) + latents_typed = latents.to(self.transformer.dtype) latent_model_input = latents_typed.repeat(2, 1, 1, 1) prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds timestep_model_input = timestep.repeat(2) else: - latent_model_input = latents if latents.dtype == dtype else latents.to(dtype) + latent_model_input = latents.to(self.transformer.dtype) prompt_embeds_model_input = prompt_embeds timestep_model_input = timestep @@ -543,11 +532,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): if apply_cfg: # Perform CFG - pos_out = model_out_list[:batch_size] - neg_out = model_out_list[batch_size:] + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] noise_pred = [] - for j in range(batch_size): + for j in range(actual_batch_size): pos = pos_out[j].float() neg = neg_out[j].float() @@ -588,11 +577,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - latents = latents.to(dtype) if output_type == "latent": image = latents else: + latents = latents.to(self.vae.dtype) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] diff --git a/tests/pipelines/z_image/test_z_image.py b/tests/pipelines/z_image/test_z_image.py new file mode 100644 index 0000000000..709473b0db --- /dev/null +++ b/tests/pipelines/z_image/test_z_image.py @@ -0,0 +1,306 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import unittest + +import numpy as np +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + ZImagePipeline, + ZImageTransformer2DModel, +) + +from ...testing_utils import torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + +# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite +# due to RopeEmbedder cache state pollution between tests. 
They pass when run individually. +# This is a known test isolation issue, not a functional bug. + + +class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ZImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = ZImageTransformer2DModel( + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=32, + n_layers=2, + n_refiner_layers=1, + n_heads=2, + n_kv_heads=2, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=16, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[8, 4, 4], + axes_lens=[256, 32, 32], + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + block_out_channels=[32, 64], + layers_per_block=1, + latent_channels=16, + norm_num_groups=32, + sample_size=32, + scaling_factor=0.3611, + shift_factor=0.1159, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3Config( + hidden_size=16, + intermediate_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + text_encoder = Qwen3Model(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "cfg_normalization": False, + "cfg_truncation": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + # fmt: off + expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 
0.5431, 0.5493, 0.4732]) + # fmt: on + + generated_slice = generated_image.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_num_images_per_prompt(self): + import inspect + + sig = inspect.signature(self.pipeline_class.__call__) + + if "num_images_per_prompt" not in sig.parameters: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1, 2] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + del pipe + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling (standard AutoencoderKL doesn't accept parameters) + pipe.vae.enable_tiling() + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4): + # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance + 
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference) + + def test_group_offloading_inference(self): + # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine. + self.skipTest("Using test_pipeline_level_group_offloading_inference instead") + + def test_save_load_float16(self, expected_max_diff=1e-2): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + super().test_save_load_float16(expected_max_diff=expected_max_diff)
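
As a closing usage note, the `num_images_per_prompt` behavior exercised by `test_num_images_per_prompt` follows the batching scheme introduced in the pipeline diff above. The sketch below illustrates it with dummy tensors; shapes are illustrative and the guidance step uses the plain CFG combination, whereas the actual pipeline additionally supports `cfg_normalization` and `cfg_truncation`:

import torch

batch_size, num_images_per_prompt = 2, 2  # illustrative sizes
# One variable-length embedding per prompt, as returned by the text encoder.
prompt_embeds = [torch.randn(8, 16) for _ in range(batch_size)]
negative_prompt_embeds = [torch.randn(8, 16) for _ in range(batch_size)]

# Repeat per-prompt embeddings so each generated image gets its own copy,
# mirroring the list comprehensions in the pipeline diff.
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt

# Under CFG the transformer runs positives and negatives in one doubled batch;
# its per-sample outputs are split at actual_batch_size before applying guidance.
model_out_list = [torch.randn(16, 4, 4) for _ in range(2 * actual_batch_size)]  # stand-in outputs
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
guidance_scale = 3.0
noise_pred = [neg + guidance_scale * (pos - neg) for pos, neg in zip(pos_out, neg_out)]
assert len(noise_pred) == actual_batch_size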