From 1c500c8eeb27ec97ef1bcfe20ab33159c4490e2b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 08:06:32 +0100 Subject: [PATCH] flux2-dev work in modular setting --- .../flux2/modular_blocks_flux2.py | 65 ++++++++++++++----- .../flux2/modular_blocks_flux2_klein.py | 10 +-- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index 66509454c3..eba2cbbd00 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import PIL.Image +from typing import List +import torch + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict +from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( Flux2PrepareGuidanceStep, Flux2PrepareImageLatentsStep, @@ -42,7 +46,6 @@ Flux2VaeEncoderBlocks = InsertableDict( [ ("preprocess", Flux2ProcessImagesInputStep()), ("encode", Flux2VaeEncoderStep()), - ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ] ) @@ -73,35 +76,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): ) -Flux2BeforeDenoiseBlocks = InsertableDict( +Flux2CoreDenoiseBlocks = InsertableDict( [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ] ) -class Flux2BeforeDenoiseStep(SequentialPipelineBlocks): +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): model_name = "flux2" - block_classes = Flux2BeforeDenoiseBlocks.values() - block_names = Flux2BeforeDenoiseBlocks.keys() + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() @property def description(self): - return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation." + return ( + "Core denoise step that performs the denoising process for Flux2-dev.\n" + " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n" + " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" + " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2TextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), - ("after_denoise", Flux2UnpackLatentsStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -110,11 +134,8 @@ AUTO_BLOCKS = InsertableDict( REMOTE_AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2RemoteTextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), - ("after_denoise", Flux2UnpackLatentsStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -134,6 +155,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks): "- For image-conditioned generation, you need to provide `image` (list of PIL images)." ) + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + TEXT2IMAGE_BLOCKS = InsertableDict( [ diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 0ecbbceb6d..984832d77b 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -43,9 +43,10 @@ from .inputs import ( logger = logging.get_logger(__name__) # pylint: disable=invalid-name -### -### VAE encoder -### +################ +# VAE encoder +################ + Flux2KleinVaeEncoderBlocks = InsertableDict( [ ("preprocess", Flux2ProcessImagesInputStep()), @@ -105,9 +106,8 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): @property def description(self): - return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." return ( - "Core denoise step that performs the denoising process for Flux2-Klein.\n" + "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n" " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"