From 5c1fc4489f95b162a2f8b5d69fc89e9bfa40c8f5 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 21 Jan 2026 00:59:56 +0100
Subject: [PATCH] move guidance to its own block

---
 .../modular_pipelines/flux2/before_denoise.py | 111 +++++++++++++++---
 .../modular_pipelines/flux2/inputs.py         |  62 ++++++++++
 .../flux2/modular_blocks_flux2.py             |   4 +
 .../flux2/modular_blocks_flux2_klein.py       |  56 ++++++++-
 4 files changed, 211 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py
index e1001924c7..d5bab16586 100644
--- a/src/diffusers/modular_pipelines/flux2/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py
@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
-            InputParam("guidance_scale", default=4.0),
             InputParam("latents", type_hint=torch.Tensor),
-            InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-            InputParam(
-                "batch_size",
-                required=True,
-                type_hint=int,
-                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
-            ),
         ]
 
     @property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
         ]
 
     @torch.no_grad()
     def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
 
         scheduler = components.scheduler
 
@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         timesteps, num_inference_steps = retrieve_timesteps(
             scheduler,
             num_inference_steps,
-            block_state.device,
+            device,
             timesteps=timesteps,
             sigmas=sigmas,
             mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         block_state.timesteps = timesteps
         block_state.num_inference_steps = num_inference_steps
 
-        batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-        guidance = guidance.expand(batch_size)
-        block_state.guidance = guidance
-
         components.scheduler.set_begin_index(0)
 
         self.set_block_state(state, block_state)
@@ -349,6 +335,60 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
     def description(self) -> str:
         return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
 
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt_embeds", required=True),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
+            ),
+        ]
+
+    @staticmethod
+    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
+        """Prepare 4D position IDs for text tokens."""
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            seq_l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, seq_l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device = prompt_embeds.device
+
+        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
+        block_state.txt_ids = block_state.txt_ids.to(device)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
@@ -511,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
 
         self.set_block_state(state, block_state)
         return components, state
+
+
+class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the guidance scale tensor for Flux2 inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("guidance_scale", default=4.0),
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        device = components._execution_device
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(batch_size)
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py
index cc078c8262..3463de1999 100644
--- a/src/diffusers/modular_pipelines/flux2/inputs.py
+++ b/src/diffusers/modular_pipelines/flux2/inputs.py
@@ -30,6 +30,68 @@ logger = logging.get_logger(__name__)
 class Flux2TextInputStep(ModularPipelineBlocks):
     model_name = "flux2"
 
+    @property
+    def description(self) -> str:
+        return (
+            "This step:\n"
+            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "batch_size",
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+            ),
+            OutputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+            ),
+            OutputParam(
+                "prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
+                description="Text embeddings used to guide the image generation",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.batch_size = block_state.prompt_embeds.shape[0]
+        block_state.dtype = block_state.prompt_embeds.dtype
+
+        _, seq_len, _ = block_state.prompt_embeds.shape
+        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+        block_state.prompt_embeds = block_state.prompt_embeds.view(
+            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
     @property
     def description(self) -> str:
         return (
diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
index bad167f842..af6c1819ec 100644
--- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
+++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
@@ -19,6 +19,7 @@ from .before_denoise import (
+    Flux2PrepareGuidanceStep,
     Flux2PrepareImageLatentsStep,
     Flux2PrepareLatentsStep,
     Flux2RoPEInputsStep,
     Flux2SetTimestepsStep,
 )
 from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -76,6 +77,7 @@ Flux2BeforeDenoiseBlocks = InsertableDict(
     [
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
     ]
 )
@@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
         ("text_input", Flux2TextInputStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
@@ -155,6 +158,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
         ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
index b681238628..6e1cb985e7 100644
--- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
+++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+import PIL.Image
+import torch
 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
+    Flux2KleinBaseRoPEInputsStep,
     Flux2PrepareImageLatentsStep,
     Flux2PrepareLatentsStep,
     Flux2RoPEInputsStep,
     Flux2SetTimestepsStep,
 )
 from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -31,6 +35,7 @@ from .encoders import (
 from .inputs import (
+    Flux2KleinBaseTextInputStep,
     Flux2ProcessImagesInputStep,
     Flux2TextInputStep,
 )
 
 
@@ -101,7 +106,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
         return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
         return (
             "Core denoise step that performs the denoising process for Flux2-Klein.\n"
-            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
             " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
             " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
             " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
@@ -110,14 +115,24 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
             " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
         )
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]
+
 
 Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
     [
-        ("input", Flux2TextInputStep()),
+        ("input", Flux2KleinBaseTextInputStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
-        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
         ("denoise", Flux2KleinBaseDenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
     ]
@@ -134,14 +149,23 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
         return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
         return (
             "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
-            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
             " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
             " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
             " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
-            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
+            " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
             " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
             " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
         )
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]
 
 ###
@@ -165,6 +189,16 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
             + " - for text-to-image generation, all you need to provide is `prompt`.\n"
         )
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="images",
+                type_hint=List[PIL.Image.Image],
+                description="The images from the decoding step.",
+            )
+        ]
+
 
 class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
     model_name = "flux2-klein"
@@ -183,3 +217,13 @@ class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
             + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
             + " - for text-to-image generation, all you need to provide is `prompt`.\n"
         )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="images",
+                type_hint=List[PIL.Image.Image],
+                description="The images from the decoding step.",
+            )
+        ]
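
Usage note (not part of the patch): the new `Flux2PrepareGuidanceStep` only materializes
`guidance_scale` as a float32 tensor of shape `(batch_size * num_images_per_prompt,)`.
A minimal sketch of the equivalent computation in plain PyTorch (values are illustrative):

    import torch

    guidance_scale, batch_size, num_images_per_prompt = 4.0, 2, 2
    # torch.full([1], ...) builds a one-element tensor; expand() broadcasts it
    # to the final batch size without allocating a copy.
    guidance = torch.full([1], guidance_scale, dtype=torch.float32)
    guidance = guidance.expand(batch_size * num_images_per_prompt)
    print(guidance)  # tensor([4., 4., 4., 4.])

The updated presets assemble as before; a minimal sketch, assuming the standard
modular-pipelines entry points:

    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.flux2.modular_blocks_flux2 import TEXT2IMAGE_BLOCKS

    # "prepare_guidance" now sits between "set_timesteps" and "prepare_rope_inputs".
    blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
    print(blocks.sub_blocks)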