From 3d02cd543ef3101d821cb09c8fcab23c6e7ead33 Mon Sep 17 00:00:00 2001 From: David Lacalle Castillo <41203448+WaterKnight1998@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:12:17 +0100 Subject: [PATCH 1/6] [PRX] Improve model compilation (#12787) * Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT). * Apply style fixes --------- Co-authored-by: github-actions[bot] Co-authored-by: Sayak Paul --- .../models/transformers/transformer_prx.py | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 18ec650bac..a87c120fdc 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn -from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging @@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor: Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)` is the number of patches. """ - return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + b, c, h, w = img.shape + p = patch_size + + # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions + img = img.reshape(b, c, h // p, p, w // p, p) + + # Permute to (B, H//p, W//p, C, p, p) using einsum + # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width + img = torch.einsum("nchpwq->nhwcpq", img) + + # Flatten to (B, L, C * p * p) + img = img.reshape(b, -1, c * p * p) + return img def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor: @@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te Reconstructed image tensor of shape `(B, C, H, W)`. 
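    Example:
        An illustrative round trip (not part of this patch) showing that `img2seq` followed by `seq2img` recovers the
        input, matching the behavior of the previous `unfold`/`fold` implementation; the shapes are arbitrary.

        >>> x = torch.randn(1, 3, 8, 8)
        >>> seq = img2seq(x, patch_size=2)  # shape (1, 16, 12): a 4 x 4 grid of patches, each flattened to 3 * 2 * 2 values
        >>> img = seq2img(seq, patch_size=2, shape=(8, 8))
        >>> torch.allclose(img, x)
        True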
""" if isinstance(shape, tuple): - shape = shape[-2:] + h, w = shape[-2:] elif isinstance(shape, torch.Tensor): - shape = (int(shape[0]), int(shape[1])) + h, w = (int(shape[0]), int(shape[1])) else: raise NotImplementedError(f"shape type {type(shape)} not supported") - return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + b, l, d = seq.shape + p = patch_size + c = d // (p * p) + + # Reshape back to grid structure: (B, H//p, W//p, C, p, p) + seq = seq.reshape(b, h // p, w // p, c, p, p) + + # Permute back to image layout: (B, C, H//p, p, W//p, p) + # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width + seq = torch.einsum("nhwcpq->nchpwq", seq) + + # Final reshape to (B, C, H, W) + seq = seq.reshape(b, c, h, w) + return seq class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): From 54fa0745c34be810e0f51d6bd528b16af4e6abe5 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 8 Dec 2025 17:58:57 +0100 Subject: [PATCH 2/6] Improve docstrings and type hints in scheduling_dpmsolver_singlestep.py (#12798) feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler --- .../scheduling_dpmsolver_singlestep.py | 124 +++++++++++------- 1 file changed, 74 insertions(+), 50 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 55c9fb6e73..4916e1abb5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -86,42 +86,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): + trained_betas (`np.ndarray` or `List[float]`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): + solver_order (`int`, defaults to `2`): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://huggingface.co/papers/2210.02303) paper). + `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`. 
thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): + dynamic_thresholding_ratio (`float`, defaults to `0.995`): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): + sample_max_value (`float`, defaults to `1.0`): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` + algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. - solver_type (`str`, defaults to `midpoint`): + solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): + lower_order_final (`bool`, defaults to `False`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): @@ -132,15 +132,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. - final_sigmas_type (`str`, *optional*, defaults to `"zero"`): + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to `1.0`): + The flow shift parameter for flow-based models. + final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. 
+ variance_type (`"learned"` or `"learned_range"`, *optional*): + Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's + output contains the predicted Gaussian variance. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to use dynamic shifting for the noise schedule. + time_shift_type (`"exponential"`, defaults to `"exponential"`): + The type of time shifting to apply. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -152,27 +160,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", + algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, + variance_type: Optional[Literal["learned", "learned_range"]] = None, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", - ): + time_shift_type: Literal["exponential"] = "exponential", + ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: @@ -242,6 +250,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + + Returns: + `List[int]`: + The list of solver orders for each timestep. """ steps = num_inference_steps order = self.config.solver_order @@ -276,21 +288,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return orders @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference. @@ -302,19 +322,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None, timesteps: Optional[List[int]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + mu (`float`, *optional*): + The mu parameter for dynamic shifting. timesteps (`List[int]`, *optional*): Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is @@ -453,7 +475,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -490,7 +512,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert sigma values to alpha_t and sigma_t values. @@ -512,7 +534,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). 
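As a usage sketch only (not part of this diff): the new flow-matching options documented above can be passed directly
to the constructor. The concrete values below (`flow_shift=3.0`, 20 steps) are illustrative choices, not defaults from
this patch.

    from diffusers import DPMSolverSinglestepScheduler

    scheduler = DPMSolverSinglestepScheduler(
        solver_order=2,
        prediction_type="flow_prediction",  # new Literal value documented above
        use_flow_sigmas=True,               # build the sigma schedule from the flow-matching formulation
        flow_shift=3.0,                     # shift applied to the flow sigmas
    )
    scheduler.set_timesteps(num_inference_steps=20)
    print(scheduler.timesteps)

Note that when `use_dynamic_shifting=True` is set instead, `set_timesteps` expects the `mu` argument documented above.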
@@ -637,7 +659,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -733,7 +755,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -797,7 +819,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -908,7 +930,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -1030,8 +1052,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, - order: int = None, + sample: Optional[torch.Tensor] = None, + order: Optional[int] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -1125,7 +1147,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. @@ -1146,7 +1168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -1156,11 +1178,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + generator (`torch.Generator`, *optional*): + A random number generator for stochastic sampling. + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
Returns: @@ -1277,5 +1301,5 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps From 07ea0786e8e1bb58a6ee25797f7ced0165f720af Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 9 Dec 2025 08:08:41 -1000 Subject: [PATCH 3/6] [Modular]z-image (#12808) * initiL * up up * fix: z_image -> z-image * style * copy * fix more * some docstring fix --- src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 5 + .../modular_pipelines/modular_pipeline.py | 1 + .../modular_pipelines/wan/encoders.py | 6 +- .../modular_pipelines/z_image/__init__.py | 57 ++ .../z_image/before_denoise.py | 621 ++++++++++++++++++ .../modular_pipelines/z_image/decoders.py | 91 +++ .../modular_pipelines/z_image/denoise.py | 310 +++++++++ .../modular_pipelines/z_image/encoders.py | 344 ++++++++++ .../z_image/modular_blocks.py | 191 ++++++ .../z_image/modular_pipeline.py | 72 ++ .../dummy_torch_and_transformers_objects.py | 30 + 12 files changed, 1730 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/modular_pipelines/z_image/__init__.py create mode 100644 src/diffusers/modular_pipelines/z_image/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/z_image/decoders.py create mode 100644 src/diffusers/modular_pipelines/z_image/denoise.py create mode 100644 src/diffusers/modular_pipelines/z_image/encoders.py create mode 100644 src/diffusers/modular_pipelines/z_image/modular_blocks.py create mode 100644 src/diffusers/modular_pipelines/z_image/modular_pipeline.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6df4ad4894..e69d334fdb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -419,6 +419,8 @@ else: "Wan22AutoBlocks", "WanAutoBlocks", "WanModularPipeline", + "ZImageAutoBlocks", + "ZImageModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -1124,6 +1126,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline, + ZImageAutoBlocks, + ZImageModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 252b9f33df..dea9da0269 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -60,6 +60,10 @@ else: "QwenImageEditPlusModularPipeline", "QwenImageEditPlusAutoBlocks", ] + _import_structure["z_image"] = [ + "ZImageAutoBlocks", + "ZImageModularPipeline", + ] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -91,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline + from .z_image import ZImageAutoBlocks, ZImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a6336de71a..bba89e6121 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), + 
("z-image", "ZImageModularPipeline"), ] ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index dc49df8eab..4fd69c6ca6 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -530,6 +530,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks): device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -555,7 +556,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks): vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) @@ -627,6 +628,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -659,7 +661,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) diff --git a/src/diffusers/modular_pipelines/z_image/__init__.py b/src/diffusers/modular_pipelines/z_image/__init__.py new file mode 100644 index 0000000000..c8a8c14396 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["decoders"] = ["ZImageVaeDecoderStep"] + _import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "ZImageAutoBlocks", + ] + _import_structure["modular_pipeline"] = ["ZImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .decoders import ZImageVaeDecoderStep + from .encoders import ZImageTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + ZImageAutoBlocks, + ) + from .modular_pipeline import ZImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/z_image/before_denoise.py b/src/diffusers/modular_pipelines/z_image/before_denoise.py new file mode 100644 index 0000000000..35ea768f12 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/before_denoise.py @@ -0,0 +1,621 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace 
Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. 
image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 dimensions. + Expected shapes: [batch, channels, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress image spatial dimension. + By default, it is 16 + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + """ + latent_height, latent_width = latents.shape[2:] + height = latent_height * vae_scale_factor_spatial // 2 + width = latent_width * vae_scale_factor_spatial // 2 + + return height, width + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageTextInputStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=List[torch.Tensor], + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if not isinstance(block_state.prompt_embeds, list): + raise ValueError( + f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}." + ) + if not isinstance(block_state.negative_prompt_embeds, list): + raise ValueError( + f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}." 
+ ) + if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds): + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but" + f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`" + f" {len(block_state.negative_prompt_embeds)}." + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = len(block_state.prompt_embeds) + block_state.dtype = block_state.prompt_embeds[0].dtype + + if block_state.num_images_per_prompt > 1: + prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)] + block_state.prompt_embeds = prompt_embeds + + if block_state.negative_prompt_embeds is not None: + negative_prompt_embeds = [ + npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt) + ] + block_state.negative_prompt_embeds = negative_prompt_embeds + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageAdditionalInputsStep(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_latents (default behavior) ZImageAdditionalInputsStep() + + # Configure to process multiple image latent inputs + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. 
For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + def check_inputs(self, components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step." 
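    # Illustrative arithmetic only (not part of this patch), assuming `vae_scale_factor == 16` for Z-Image:
    # a 1024 x 1024 request in the prepare-latents step above gives
    #     latent_height = 2 * (1024 // (16 * 2)) = 64, latent_width = 64
    # so `latents` has shape (batch_size * num_images_per_prompt, num_channels_latents, 64, 64), and the
    # `__call__` below derives image_seq_len = (64 // 2) * (64 // 2) = 1024, which `calculate_shift` maps to
    # the `mu` used by the flow-matching scheduler.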
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("num_inference_steps", default=9), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3] + image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify + + mu = calculate_shift( + image_seq_len, + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + components.scheduler.sigma_min = 0.0 + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=block_state.sigmas, + mu=mu, + ) + + self.set_block_state(state, block_state) + return components, state + + +class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("strength", default=0.6), + ] + + def check_inputs(self, components, block_state): + if block_state.strength < 0.0 or block_state.strength > 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}") + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps) + + t_start = int(max(block_state.num_inference_steps - init_timestep, 0)) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + block_state.timesteps = timesteps + block_state.num_inference_steps = block_state.num_inference_steps - t_start + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step." 
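    # Worked example for the strength handling above (illustrative only, not part of this patch):
    # with num_inference_steps = 9 and strength = 0.6,
    #     init_timestep = min(9 * 0.6, 9) = 5.4
    #     t_start = int(max(9 - 5.4, 0)) = 3
    # and since FlowMatchEulerDiscreteScheduler has order == 1, the step keeps timesteps[3:] and leaves
    # 9 - 3 = 6 denoising steps; a larger strength keeps more of the schedule and alters the input image more.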
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("image_latents", required=True), + InputParam("timesteps", required=True), + ] + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/decoders.py b/src/diffusers/modular_pipelines/z_image/decoders.py new file mode 100644 index 0000000000..cdb6a2e5ea --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/decoders.py @@ -0,0 +1,91 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + ), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents.to(vae_dtype) + latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + 
self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py new file mode 100644 index 0000000000..ec815f77ad --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -0,0 +1,310 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents.unsqueeze(2).to( + block_state.dtype + ) # [batch_size, num_channels, 1, height, width] + block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width] + + timestep = t.expand(latents.shape[0]).to(block_state.dtype) + timestep = (1000 - timestep) / 1000 + block_state.timestep = timestep + return components, block_state + + +class ZImageLoopDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. 
For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). 
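        # With this block's default mapping, guider_input_fields = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")},
        # the two CFG batches therefore look like (illustrative restatement, not new behavior):
        #     guider_state = [
        #         {"cap_feats": block_state.prompt_embeds, "__guidance_identifier__": "pred_cond"},
        #         {"cap_feats": block_state.negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},
        #     ]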
+ guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + + def _convert_dtype(v, dtype): + if isinstance(v, torch.Tensor): + return v.to(dtype) + elif isinstance(v, list): + return [_convert_dtype(t, dtype) for t in v] + return v + + cond_kwargs = { + k: _convert_dtype(v, block_state.dtype) + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + model_out_list = components.transformer( + x=block_state.latent_model_input, + t=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + noise_pred = torch.stack(model_out_list, dim=0).squeeze(2) + guider_state_batch.noise_pred = -noise_pred + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class ZImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageDenoiseStep(ZImageDenoiseLoopWrapper): + block_classes = [ + ZImageLoopBeforeDenoiser, + ZImageLoopDenoiser( + guider_input_fields={ + "cap_feats": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + ZImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `ZImageLoopBeforeDenoiser`\n" + " - `ZImageLoopDenoiser`\n" + " - `ZImageLoopAfterDenoiser`\n" + "This block supports text-to-image and image-to-image tasks for Z-Image." + ) diff --git a/src/diffusers/modular_pipelines/z_image/encoders.py b/src/diffusers/modular_pipelines/z_image/encoders.py new file mode 100644 index 0000000000..f5769fe2de --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/encoders.py @@ -0,0 +1,344 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
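The `ZImageDenoiseLoopWrapper` and `ZImageDenoiseStep` defined above run three sub-blocks in order at every timestep. A toy, self-contained illustration of that chaining; the names and strings are placeholders, and this is not the actual `LoopSequentialPipelineBlocks` machinery:

class ToyState:
    """Minimal stand-in for BlockState in this sketch."""

def before(state, i, t):
    state.latent_model_input = f"latents@t={t}"
    return state

def denoiser(state, i, t):
    state.noise_pred = f"eps({state.latent_model_input})"
    return state

def after(state, i, t):
    state.latents = f"scheduler_step({state.noise_pred})"
    return state

sub_blocks = [before, denoiser, after]  # mirrors ["before_denoiser", "denoiser", "after_denoiser"]
state = ToyState()
for i, t in enumerate([999, 500, 1]):   # toy timesteps
    for block in sub_blocks:            # what loop_step does on each iteration
        state = block(state, i, t)
print(state.latents)                    # scheduler_step(eps(latents@t=1))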
+ +from typing import List, Optional, Union + +import PIL +import torch +from transformers import Qwen2Tokenizer, Qwen3Model + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +if is_ftfy_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_qwen_prompt_embeds( + text_encoder: Qwen3Model, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: torch.device, + max_sequence_length: int = 512, +) -> List[torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + prompt_embeds_list = [] + + for i in range(len(prompt_embeds)): + prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]]) + + return prompt_embeds_list + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image_tensor: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(image_tensor, torch.Tensor): + raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.") + + if isinstance(generator, list) and len(generator) != image_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}." 
+ ) + + image_tensor = image_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i]) + for i in range(image_tensor.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class ZImageTextEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=List[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=List[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: Optional[torch.device] = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. 
+ """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_prompt_embeds = None + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class ZImageVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on image to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." 
+ ) + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + image_tensor = components.image_processor.preprocess( + image, height=block_state.height, width=block_state.width + ).to(device=device, dtype=dtype) + + block_state.image_latents = encode_vae_image( + image_tensor=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/modular_blocks.py b/src/diffusers/modular_pipelines/z_image/modular_blocks.py new file mode 100644 index 0000000000..a7c520301a --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/modular_blocks.py @@ -0,0 +1,191 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + ZImageAdditionalInputsStep, + ZImagePrepareLatentsStep, + ZImagePrepareLatentswithImageStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImageTextInputStep, +) +from .decoders import ZImageVaeDecoderStep +from .denoise import ( + ZImageDenoiseStep, +) +from .encoders import ( + ZImageTextEncoderStep, + ZImageVaeImageEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# z-image +# text2image +class ZImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +# z-image: image2image +## denoise +class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImagePrepareLatentswithImageStep, + ZImageDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "prepare_latents", + "set_timesteps", + 
"set_timesteps_with_strength", + "prepare_latents_with_image", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n" + + " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +## auto blocks +class ZImageAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + ZImageImage2ImageCoreDenoiseStep, + ZImageCoreDenoiseStep, + ] + block_names = ["image2image", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2image and image2image tasks." + " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks." + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks." + + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n" + + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n" + ) + + +class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): + block_classes = [ZImageVaeImageEncoderStep] + block_names = ["vae_image_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self) -> str: + return "Vae Image Encoder step that encode the image to generate the image latents" + +"This is an auto pipeline block that works for image2image tasks." + +" - `ZImageVaeImageEncoderStep` is used when `image` is provided." + +" - if `image` is not provided, step will be skipped." + + +class ZImageAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + ZImageTextEncoderStep, + ZImageAutoVaeImageEncoderStep, + ZImageAutoDenoiseStep, + ZImageVaeDecoderStep, + ] + block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + + @property + def description(self) -> str: + return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n" + +" - for text-to-image generation, all you need to provide is `prompt`\n" + +" - for image-to-image generation, you need to provide `image`\n" + +" - if `image` is not provided, step will be skipped." 
+ + +# presets +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("input", ZImageTextInputStep), + ("prepare_latents", ZImagePrepareLatentsStep), + ("set_timesteps", ZImageSetTimestepsStep), + ("denoise", ZImageDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + +IMAGE2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("vae_image_encoder", ZImageVaeImageEncoderStep), + ("input", ZImageTextInputStep), + ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])), + ("prepare_latents", ZImagePrepareLatentsStep), + ("set_timesteps", ZImageSetTimestepsStep), + ("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep), + ("prepare_latents_with_image", ZImagePrepareLatentswithImageStep), + ("denoise", ZImageDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("vae_image_encoder", ZImageAutoVaeImageEncoderStep), + ("denoise", ZImageAutoDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "image2image": IMAGE2IMAGE_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/z_image/modular_pipeline.py b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py new file mode 100644 index 0000000000..f1d8e53a36 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py @@ -0,0 +1,72 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import ZImageLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageModularPipeline( + ModularPipeline, + ZImageLoraLoaderMixin, +): + """ + A ModularPipeline for Z-Image. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "ZImageAutoBlocks" + + @property + def default_height(self): + return 1024 + + @property + def default_width(self): + return 1024 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor_spatial = 16 + if hasattr(self, "image_processor") and self.image_processor is not None: + vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor + return vae_scale_factor_spatial + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 79a21d2ac6..da64742518 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -227,6 +227,36 @@ class WanModularPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class ZImageAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ZImageModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 8b4722de57a9a2646466b8bb7095c4fd465193fa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 10 Dec 2025 04:08:30 +0800 Subject: [PATCH 4/6] Fix Qwen Edit Plus modular for multi-image input (#12601) * try to fix qwen edit plus multi images (modular) * up * up * test * up * up --- .../qwenimage/before_denoise.py | 32 +++++++- .../modular_pipelines/qwenimage/encoders.py | 82 +++++++++++++++---- .../modular_pipelines/qwenimage/inputs.py | 76 +++++++++++++++-- .../qwenimage/modular_blocks.py | 79 +++++++++++++++--- .../qwen/test_modular_pipeline_qwenimage.py | 13 ++- 5 files changed, 247 insertions(+), 35 deletions(-) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 0e470332c6..bd92d40353 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -610,7 +610,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): block_state = self.get_block_state(state) # 
for edit, image size can be different from the target size (height/width) - block_state.img_shapes = [ [ ( @@ -640,6 +639,37 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): return components, state +class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep): + model_name = "qwenimage-edit-plus" + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + block_state.img_shapes = [ + [ + (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), + *[ + (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2) + for vae_height, vae_width in zip(block_state.image_height, block_state.image_width) + ], + ] + ] * block_state.batch_size + + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + ## ControlNet inputs for denoiser class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 3b56981e52..b126a368bf 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -330,7 +330,7 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep): output_name: str = "resized_image", vae_image_output_name: str = "vae_image", ): - """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. + """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio. This block resizes an input image or a list input images and exposes the resized result under configurable input and output names. 
Use this when you need to wire the resize step to different image fields (e.g., @@ -809,9 +809,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks): @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam(name="processed_image"), - ] + return [OutputParam(name="processed_image")] @staticmethod def check_inputs(height, width, vae_scale_factor): @@ -851,7 +849,10 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks): class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): model_name = "qwenimage-edit-plus" - vae_image_size = 1024 * 1024 + + def __init__(self): + self.vae_image_size = 1024 * 1024 + super().__init__() @property def description(self) -> str: @@ -868,6 +869,7 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): if block_state.vae_image is None and block_state.image is None: raise ValueError("`vae_image` and `image` cannot be None at the same time") + vae_image_sizes = None if block_state.vae_image is None: image = block_state.image self.check_inputs( @@ -879,12 +881,19 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): image=image, height=height, width=width ) else: - width, height = block_state.vae_image[0].size - image = block_state.vae_image + # QwenImage Edit Plus can allow multiple input images with varied resolutions + processed_images = [] + vae_image_sizes = [] + for img in block_state.vae_image: + width, height = img.size + vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height) + vae_image_sizes.append((vae_width, vae_height)) + processed_images.append( + components.image_processor.preprocess(image=img, height=vae_height, width=vae_width) + ) + block_state.processed_image = processed_images - block_state.processed_image = components.image_processor.preprocess( - image=image, height=height, width=width - ) + block_state.vae_image_sizes = vae_image_sizes self.set_block_state(state, block_state) return components, state @@ -926,17 +935,12 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): @property def expected_components(self) -> List[ComponentSpec]: - components = [ - ComponentSpec("vae", AutoencoderKLQwenImage), - ] + components = [ComponentSpec("vae", AutoencoderKLQwenImage)] return components @property def inputs(self) -> List[InputParam]: - inputs = [ - InputParam(self._image_input_name, required=True), - InputParam("generator"), - ] + inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")] return inputs @property @@ -974,6 +978,50 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): return components, state +class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep): + model_name = "qwenimage-edit-plus" + + @property + def intermediate_outputs(self) -> List[OutputParam]: + # Each reference image latent can have varied resolutions hence we return this as a list. 
+ return [ + OutputParam( + self._image_latents_output_name, + type_hint=List[torch.Tensor], + description="The latents representing the reference image(s).", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.vae.dtype + + image = getattr(block_state, self._image_input_name) + + # Encode image into latents + image_latents = [] + for img in image: + image_latents.append( + encode_vae_image( + image=img, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + ) + + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 2b229c040b..6e656e4848 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -224,11 +224,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks): class QwenImageInputsDynamicStep(ModularPipelineBlocks): model_name = "qwenimage" - def __init__( - self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], - ): + def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []): """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" This step handles multiple common tasks to prepare inputs for the denoising step: @@ -372,6 +368,76 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks): return components, state +class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep): + model_name = "qwenimage-edit-plus" + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"), + OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # Each image latent can have different size in QwenImage Edit Plus. + image_heights = [] + image_widths = [] + packed_image_latent_tensors = [] + + for img_latent_tensor in image_latent_tensor: + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) + image_heights.append(height) + image_widths.append(width) + + # 2. Patchify the image latent tensor + img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) + + # 3. 
Expand batch size + img_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=img_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + packed_image_latent_tensors.append(img_latent_tensor) + + packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + block_state.image_height = image_heights + block_state.image_width = image_widths + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + block_state.height = block_state.height or image_heights[-1] + block_state.width = block_state.width or image_widths[-1] + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageControlNetInputsStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 4198941643..55a7ae328f 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -18,6 +18,7 @@ from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( QwenImageControlNetBeforeDenoiserStep, QwenImageCreateMaskLatentsStep, + QwenImageEditPlusRoPEInputsStep, QwenImageEditRoPEInputsStep, QwenImagePrepareLatentsStep, QwenImagePrepareLatentsWithStrengthStep, @@ -40,6 +41,7 @@ from .encoders import ( QwenImageEditPlusProcessImagesInputStep, QwenImageEditPlusResizeDynamicStep, QwenImageEditPlusTextEncoderStep, + QwenImageEditPlusVaeEncoderDynamicStep, QwenImageEditResizeDynamicStep, QwenImageEditTextEncoderStep, QwenImageInpaintProcessImagesInputStep, @@ -47,7 +49,12 @@ from .encoders import ( QwenImageTextEncoderStep, QwenImageVaeEncoderDynamicStep, ) -from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep +from .inputs import ( + QwenImageControlNetInputsStep, + QwenImageEditPlusInputsDynamicStep, + QwenImageInputsDynamicStep, + QwenImageTextInputsStep, +) logger = logging.get_logger(__name__) @@ -904,13 +911,13 @@ QwenImageEditPlusVaeEncoderBlocks = InsertableDict( [ ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents ] ) class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit-plus" block_classes = QwenImageEditPlusVaeEncoderBlocks.values() block_names = QwenImageEditPlusVaeEncoderBlocks.keys() @@ -919,25 +926,62 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): return "Vae encoder step that encode the image inputs into their latent representations." 
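Because every reference image can have a different resolution, `QwenImageEditPlusInputsDynamicStep` packs each latent into a token sequence before concatenating along the sequence dimension (dim=1). A small sketch of why that works, assuming Flux-style 2x2 patch packing; the actual `pachifier.pack_latents` implementation is not shown in this diff:

import torch

def pack_latents_2x2(latents):
    # Packing sketch: (B, C, H, W) -> (B, (H//2)*(W//2), C*4).
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# Two reference latents with different spatial sizes become sequences of different
# lengths with the same channel width, so they can still be concatenated along dim=1.
lat_a = torch.randn(1, 16, 64, 64)
lat_b = torch.randn(1, 16, 48, 96)
packed = [pack_latents_2x2(x) for x in (lat_a, lat_b)]
print([p.shape for p in packed])       # sequence lengths 1024 and 1152, both 64 channels
print(torch.cat(packed, dim=1).shape)  # (1, 2176, 64)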
+#### QwenImage Edit Plus input blocks +QwenImageEditPlusInputBlocks = InsertableDict( + [ + ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings + ( + "additional_inputs", + QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]), + ), + ] +) + + +class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = QwenImageEditPlusInputBlocks.values() + block_names = QwenImageEditPlusInputBlocks.keys() + + #### QwenImage Edit Plus presets EDIT_PLUS_BLOCKS = InsertableDict( [ ("text_encoder", QwenImageEditPlusVLEncoderStep()), ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), - ("input", QwenImageEditInputStep()), + ("input", QwenImageEditPlusInputStep()), ("prepare_latents", QwenImagePrepareLatentsStep()), ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) +QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), + ] +) + + +class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values() + block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." + + # auto before_denoise step for edit tasks class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): model_name = "qwenimage-edit-plus" - block_classes = [QwenImageEditBeforeDenoiseStep] + block_classes = [QwenImageEditPlusBeforeDenoiseStep] block_names = ["edit"] block_trigger_inputs = ["image_latents"] @@ -946,7 +990,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): return ( "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + "This is an auto pipeline block that works for edit (img2img) task.\n" - + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + + " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + " - if `image_latents` is not provided, step will be skipped." 
) @@ -955,9 +999,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [ - QwenImageEditPlusVaeEncoderStep, - ] + block_classes = [QwenImageEditPlusVaeEncoderStep] block_names = ["edit"] block_trigger_inputs = ["image"] @@ -974,10 +1016,25 @@ class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): ## 3.3 QwenImage-Edit/auto blocks & presets +class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks): + block_classes = [QwenImageEditPlusInputStep] + block_names = ["edit"] + block_trigger_inputs = ["image_latents"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step.\n" + + " It is an auto pipeline block that works for edit task.\n" + + " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n" + + " - if `image_latents` is not provided, step will be skipped." + ) + + class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditAutoInputStep, + QwenImageEditPlusAutoInputStep, QwenImageEditPlusAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, ] diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py index 8d7600781b..f4bd27b7ea 100644 --- a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py +++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py @@ -26,6 +26,7 @@ from diffusers.modular_pipelines import ( QwenImageModularPipeline, ) +from ...testing_utils import torch_device from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin @@ -104,6 +105,16 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul inputs["image"] = PIL.Image.new("RGB", (32, 32), 0) return inputs + def test_multi_images_as_input(self): + inputs = self.get_dummy_inputs() + image = inputs.pop("image") + inputs["image"] = [image, image] + + pipe = self.get_pipeline().to(torch_device) + _ = pipe( + **inputs, + ) + @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) def test_num_images_per_prompt(self): super().test_num_images_per_prompt() @@ -117,4 +128,4 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul super().test_inference_batch_single_identical() def test_guider_cfg(self): - super().test_guider_cfg(1e-3) + super().test_guider_cfg(1e-6) From be3c2a0667493022f17d756ca3dba631d28dfb40 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 10 Dec 2025 12:19:07 +0530 Subject: [PATCH 5/6] [WIP] Add Flux2 modular (#12763) * update * update * update * update * update * update * update * update * update * update --- src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 5 + .../modular_pipelines/flux2/__init__.py | 111 ++++ .../modular_pipelines/flux2/before_denoise.py | 508 ++++++++++++++++++ .../modular_pipelines/flux2/decoders.py | 146 +++++ .../modular_pipelines/flux2/denoise.py | 252 +++++++++ .../modular_pipelines/flux2/encoders.py | 349 ++++++++++++ .../modular_pipelines/flux2/inputs.py | 160 ++++++ .../modular_pipelines/flux2/modular_blocks.py | 166 ++++++ .../flux2/modular_pipeline.py | 57 ++ .../modular_pipelines/modular_pipeline.py | 2 +- .../dummy_torch_and_transformers_objects.py | 30 ++ tests/modular_pipelines/flux2/__init__.py | 0 
.../flux2/test_modular_pipeline_flux2.py | 93 ++++ .../test_modular_pipelines_common.py | 1 - 15 files changed, 1882 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/modular_pipelines/flux2/__init__.py create mode 100644 src/diffusers/modular_pipelines/flux2/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/flux2/decoders.py create mode 100644 src/diffusers/modular_pipelines/flux2/denoise.py create mode 100644 src/diffusers/modular_pipelines/flux2/encoders.py create mode 100644 src/diffusers/modular_pipelines/flux2/inputs.py create mode 100644 src/diffusers/modular_pipelines/flux2/modular_blocks.py create mode 100644 src/diffusers/modular_pipelines/flux2/modular_pipeline.py create mode 100644 tests/modular_pipelines/flux2/__init__.py create mode 100644 tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e69d334fdb..fdd27f2464 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -404,6 +404,8 @@ except OptionalDependencyNotAvailable: else: _import_structure["modular_pipelines"].extend( [ + "Flux2AutoBlocks", + "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -1111,6 +1113,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + Flux2AutoBlocks, + Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index dea9da0269..5fcc1a176d 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -52,6 +52,10 @@ else: "FluxKontextAutoBlocks", "FluxKontextModularPipeline", ] + _import_structure["flux2"] = [ + "Flux2AutoBlocks", + "Flux2ModularPipeline", + ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", "QwenImageModularPipeline", @@ -75,6 +79,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline + from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py new file mode 100644 index 0000000000..21a41c1fe9 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -0,0 +1,111 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = [ + "Flux2TextEncoderStep", + "Flux2RemoteTextEncoderStep", + "Flux2VaeEncoderStep", + ] + _import_structure["before_denoise"] = [ + "Flux2SetTimestepsStep", + "Flux2PrepareLatentsStep", + "Flux2RoPEInputsStep", + "Flux2PrepareImageLatentsStep", + ] + _import_structure["denoise"] = [ + 
"Flux2LoopDenoiser", + "Flux2LoopAfterDenoiser", + "Flux2DenoiseLoopWrapper", + "Flux2DenoiseStep", + ] + _import_structure["decoders"] = ["Flux2DecodeStep"] + _import_structure["inputs"] = [ + "Flux2ProcessImagesInputStep", + "Flux2TextInputStep", + ] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "REMOTE_AUTO_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "IMAGE_CONDITIONED_BLOCKS", + "Flux2AutoBlocks", + "Flux2AutoVaeEncoderStep", + "Flux2BeforeDenoiseStep", + "Flux2VaeEncoderSequentialStep", + ] + _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, + ) + from .decoders import Flux2DecodeStep + from .denoise import ( + Flux2DenoiseLoopWrapper, + Flux2DenoiseStep, + Flux2LoopAfterDenoiser, + Flux2LoopDenoiser, + ) + from .encoders import ( + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, + ) + from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, + ) + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + IMAGE_CONDITIONED_BLOCKS, + REMOTE_AUTO_BLOCKS, + TEXT2IMAGE_BLOCKS, + Flux2AutoBlocks, + Flux2AutoVaeEncoderStep, + Flux2BeforeDenoiseStep, + Flux2VaeEncoderSequentialStep, + ) + from .modular_pipeline import Flux2ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py new file mode 100644 index 0000000000..42624688ad --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py @@ -0,0 +1,508 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + """Compute empirical mu for Flux2 timestep scheduling.""" + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Flux2SetTimestepsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("guidance_scale", default=4.0), + InputParam("latents", type_hint=torch.Tensor), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + scheduler = components.scheduler + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + vae_scale_factor = components.vae_scale_factor + + latent_height = 2 * (int(height) // (vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + timesteps = block_state.timesteps + + if timesteps is None and sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + block_state.device, + timesteps=timesteps, + sigmas=sigmas, + mu=mu, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return 
[] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"), + ] + + @staticmethod + def check_inputs(components, block_state): + vae_scale_factor = components.vae_scale_factor + if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def _prepare_latent_ids(latents: torch.Tensor): + """ + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents: Latent tensor of shape (B, C, H, W) + + Returns: + Position IDs tensor of shape (B, H*W, 4) + """ + batch_size, _, height, width = latents.shape + + t = torch.arange(1) + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) + + latent_ids = torch.cartesian_prod(t, h, w, l) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(block_state.device) + + latents = self._pack_latents(latents) + + block_state.latents = latents + block_state.latent_ids = latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="latent_ids"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + OutputParam( + name="latent_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareImageLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares image latents and their position IDs for Flux2 image conditioning." 
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image_latents", type_hint=List[torch.Tensor]), + InputParam("batch_size", required=True, type_hint=int), + InputParam("num_images_per_prompt", default=1, type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents", + ), + ] + + @staticmethod + def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10): + """ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + Args: + image_latents: A list of image latent feature tensors of shape (1, C, H, W). + scale: Factor used to define the time separation between latents. + + Returns: + Combined coordinate tensor of shape (1, N_total, 4) + """ + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image_latents = block_state.image_latents + + if image_latents is None: + block_state.image_latents = None + block_state.image_latent_ids = None + self.set_block_state(state, block_state) + + return components, state + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + image_latent_ids = self._prepare_image_ids(image_latents) + + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) + packed = packed.squeeze(0) + packed_latents.append(packed) + + image_latents = torch.cat(packed_latents, dim=0) + image_latents = image_latents.unsqueeze(0) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + block_state.image_latents = image_latents + block_state.image_latent_ids = image_latent_ids + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py new file mode 100644 index 0000000000..b769d91198 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -0,0 +1,146 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="Position IDs for the latents, used for unpacking", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor: + """ + Unpack latents using position IDs to scatter tokens into place. 
+ + Args: + x: Packed latents tensor of shape (B, seq_len, C) + x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates + + Returns: + Unpacked latents tensor of shape (B, C, H, W) + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + @staticmethod + def _unpatchify_latents(latents): + """Convert patchified latents back to regular format.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if block_state.output_type == "latent": + block_state.images = block_state.latents + else: + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + + latents = self._unpatchify_latents(latents) + + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py new file mode 100644 index 0000000000..c12eca65c6 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -0,0 +1,252 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, List, Tuple + +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2LoopDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +class Flux2LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", 
FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents after denoising. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process.", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2LoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." 
+ ) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py new file mode 100644 index 0000000000..6cb0e3bf0a --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -0,0 +1,349 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from ...models import AutoencoderKLFlux2 +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def format_text_input(prompts: List[str], system_message: str = None): + """Format prompts for Mistral3 chat template.""" + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2TextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + # fmt: off + DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." 
+ # fmt: on + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Mistral3ForConditionalGeneration), + ComponentSpec("tokenizer", AutoProcessor), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from Mistral3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + prompt_embeds = getattr(block_state, "prompt_embeds", None) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_mistral_3_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: Tuple[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + if block_state.prompt_embeds is not None: + self.set_block_state(state, block_state) + return components, state + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_mistral_3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + system_message=self.DEFAULT_SYSTEM_MESSAGE, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using a remote API endpoint" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from remote API used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + prompt_embeds = getattr(block_state, "prompt_embeds", None) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + import io + + import requests + from huggingface_hub import get_token + + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + if block_state.prompt_embeds is not None: + self.set_block_state(state, block_state) + return components, state + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + response = requests.post( + self.REMOTE_URL, + json={"prompt": prompt}, + headers={ + "Authorization": f"Bearer {get_token()}", + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + + block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True) + block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2VaeEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLFlux2)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("condition_images", type_hint=List[torch.Tensor]), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=List[torch.Tensor], + description="List of latent representations for each reference image", + ), + ] + + @staticmethod + def _patchify_latents(latents): + """Convert latents to patchified format for Flux2.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator): + """Encode a single image using Flux2 VAE with batch norm normalization.""" + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps) + latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + condition_images = block_state.condition_images + + if condition_images is None: + return components, state + + device = components._execution_device + dtype = components.vae.dtype + + image_latents = [] + for image in condition_images: + 
image = image.to(device=device, dtype=dtype) + latent = self._encode_vae_image( + vae=components.vae, + image=image, + generator=block_state.generator, + ) + image_latents.append(latent) + + block_state.image_latents = image_latents + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py new file mode 100644 index 0000000000..c9e337fb0b --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -0,0 +1,160 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + +from ...configuration_utils import FrozenDict +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) + + +class Flux2TextInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings from Mistral3. 
Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Image preprocess step for Flux2. Validates and preprocesses reference images." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.condition_images = None + self.set_block_state(state, block_state) + return components, state + + if not isinstance(images, list): + images = [images] + + condition_images = [] + for img in images: + components.image_processor.check_image_input(img) + + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = components.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_img = components.image_processor.preprocess( + img, height=image_height, width=image_width, resize_mode="crop" + ) + condition_images.append(condition_img) + + if block_state.height is None: + block_state.height = image_height + if block_state.width is None: + block_state.width = image_width + + block_state.condition_images = condition_images + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks.py b/src/diffusers/modular_pipelines/flux2/modular_blocks.py new file mode 100644 index 0000000000..a31673b6e7 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks.py @@ -0,0 
+1,166 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep +from .denoise import Flux2DenoiseStep +from .encoders import ( + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +Flux2VaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ] +) + + +class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2VaeEncoderBlocks.values() + block_names = Flux2VaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning." + + +class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2VaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +Flux2BeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ] +) + + +class Flux2BeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2BeforeDenoiseBlocks.values() + block_names = Flux2BeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation." 
+ + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("vae_image_encoder", Flux2AutoVaeEncoderStep()), + ("before_denoise", Flux2BeforeDenoiseStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +REMOTE_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2RemoteTextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("vae_image_encoder", Flux2AutoVaeEncoderStep()), + ("before_denoise", Flux2BeforeDenoiseStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +class Flux2AutoBlocks(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n" + "- For text-to-image generation, all you need to provide is `prompt`.\n" + "- For image-conditioned generation, you need to provide `image` (list of PIL images)." + ) + + +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + +IMAGE_CONDITIONED_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("preprocess_images", Flux2ProcessImagesInputStep()), + ("vae_encoder", Flux2VaeEncoderStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "image_conditioned": IMAGE_CONDITIONED_BLOCKS, + "auto": AUTO_BLOCKS, + "remote": REMOTE_AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py new file mode 100644 index 0000000000..3e497f3b1e --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -0,0 +1,57 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import Flux2LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "Flux2AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index bba89e6121..17c0117bff 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -58,6 +58,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( ("wan", "WanModularPipeline"), ("flux", "FluxModularPipeline"), ("flux-kontext", "FluxKontextModularPipeline"), + ("flux2", "Flux2ModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), @@ -1586,7 +1587,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) @property diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da64742518..ff65372f3c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class Flux2AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/flux2/__init__.py b/tests/modular_pipelines/flux2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py new file mode 100644 index 0000000000..8fd529e97e --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2AutoBlocks, + Flux2ModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index a33951dac5..661fcc2537 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -165,7 +165,6 @@ class ModularPipelineTesterMixin: expected_max_diff=1e-4, ): pipe = self.get_pipeline().to(torch_device) - inputs = 
self.get_dummy_inputs() # Reset generator in case it is has been used in self.get_dummy_inputs From 6708f5c76d50be208b8043c58e142d6551e4fba5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 11 Dec 2025 00:25:07 +0800 Subject: [PATCH 6/6] [docs] improve distributed inference cp docs. (#12810) * improve distributed inference cp docs. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/training/distributed_inference.md | 85 ++++++++++++++----- 1 file changed, 64 insertions(+), 21 deletions(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index f9756e1a67..534124cb93 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends. +Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible. + ### Ring Attention Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency. @@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf ```py import torch -from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig +from torch import distributed as dist +from diffusers import DiffusionPipeline, ContextParallelConfig -try: - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - device = torch.device("cuda", rank % torch.cuda.device_count()) +def setup_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - - transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2)) - pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda") - pipeline.transformer.set_attention_backend("flash") + return device + +def main(): + device = setup_distributed() + world_size = dist.get_world_size() + + pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device + ) + pipeline.transformer.set_attention_backend("_native_cudnn") + + cp_config = ContextParallelConfig(ring_degree=world_size) + pipeline.transformer.enable_parallelism(config=cp_config) prompt = """ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain """ - + # Must specify generator so all ranks start with same latents (or pass your own) generator = torch.Generator().manual_seed(42) - image = pipeline(prompt, 
num_inference_steps=50, generator=generator).images[0]
-
-    if rank == 0:
-        image.save("output.png")
+    image = pipeline(
+        prompt,
+        guidance_scale=3.5,
+        num_inference_steps=50,
+        generator=generator,
+    ).images[0]
 
-except Exception as e:
-    print(f"An error occurred: {e}")
-    torch.distributed.breakpoint()
-    raise
+    if dist.get_rank() == 0:
+        image.save("output.png")
 
-finally:
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
 ```
 
+The script above must be launched with a PyTorch-compatible distributed launcher such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.
+
+```shell
+torchrun --nproc-per-node 2 above_script.py
+```
+
 ### Ulysses Attention
 
 [Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
 
@@ -288,5 +310,26 @@ finally:
 
 Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
 
 ```py
+# Depending on the number of GPUs available.
 pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
+
+### parallel_config
+
+Pass `parallel_config` during model initialization to enable context parallelism.
+
+```py
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+cp_config = ContextParallelConfig(ring_degree=2)
+transformer = AutoModel.from_pretrained(
+    CKPT_ID,
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16,
+    parallel_config=cp_config
+)
+
+pipeline = DiffusionPipeline.from_pretrained(
+    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
+).to(device)
 ```
\ No newline at end of file
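
A minimal end-to-end sketch of the `parallel_config` path described above, assuming the same distributed setup and `torchrun` launch as the Ring Attention example; the checkpoint ID, prompt, and seed are reused from that example, and `ring_degree` is assumed to equal the launcher's world size:

```py
import torch
from torch import distributed as dist

from diffusers import AutoModel, ContextParallelConfig, DiffusionPipeline

# One process per GPU; launch with `torchrun --nproc-per-node <num_gpus> this_script.py`.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

CKPT_ID = "black-forest-labs/FLUX.1-dev"

# ring_degree is assumed to match the number of processes started by the launcher.
cp_config = ContextParallelConfig(ring_degree=dist.get_world_size())
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)
pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16
).to(device)

# Same seed on every rank so all processes denoise identical latents.
generator = torch.Generator().manual_seed(42)
image = pipeline(
    "cinematic film still of a cat sipping a margarita in a pool in Palm Springs",
    num_inference_steps=50,
    generator=generator,
).images[0]

if rank == 0:
    image.save("output.png")

dist.destroy_process_group()
```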