From c10041e57e1f9c3cf2d3ff96fd535e25dfa4f150 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 20 Jan 2026 23:13:53 +0100
Subject: [PATCH] a few fixes: unpack latents before the decoder, etc.

---
 .../modular_pipelines/flux2/decoders.py      | 109 +++++++++-----
 .../modular_pipelines/flux2/denoise.py       |  14 --
 .../modular_pipelines/flux2/encoders.py      | 138 ++++++++++++++++++
 .../flux2/modular_blocks_flux2.py            |   6 +-
 .../flux2/modular_blocks_flux2_klein.py      |  20 ++-
 5 files changed, 233 insertions(+), 54 deletions(-)

diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py
index b769d91198..e881367208 100644
--- a/src/diffusers/modular_pipelines/flux2/decoders.py
+++ b/src/diffusers/modular_pipelines/flux2/decoders.py
@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class Flux2DecodeStep(ModularPipelineBlocks):
+class Flux2UnpackLatentsStep(ModularPipelineBlocks):
     model_name = "flux2"
 
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return [
-            ComponentSpec("vae", AutoencoderKLFlux2),
-            ComponentSpec(
-                "image_processor",
-                Flux2ImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
-                default_creation_method="from_config",
-            ),
-        ]
-
     @property
     def description(self) -> str:
-        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
+        return "Step that unpacks the latents from the denoising step"
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("output_type", default="pil"),
             InputParam(
                 "latents",
                 required=True,
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
     def intermediate_outputs(self) -> List[str]:
         return [
             OutputParam(
-                "images",
-                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
-                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+                "latents",
+                type_hint=torch.Tensor,
+                description="The denoised latents from the denoising step, unpacked with position IDs.",
             )
         ]
 
@@ -107,6 +94,64 @@ class Flux2DecodeStep(ModularPipelineBlocks):
 
         return torch.stack(x_list, dim=0)
 
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+
+        latents = block_state.latents
+        latent_ids = block_state.latent_ids
+
+        latents = self._unpack_latents_with_ids(latents, latent_ids)
+
+        block_state.latents = latents
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2DecodeStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLFlux2),
+            ComponentSpec(
+                "image_processor",
+                Flux2ImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("output_type", default="pil"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The denoised latents from the denoising step",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[str]:
+        return [
+            OutputParam(
+                "images",
+                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
+                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+            )
+        ]
+
     @staticmethod
     def _unpatchify_latents(latents):
         """Convert patchified latents back to regular format."""
@@ -121,26 +166,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
         block_state = self.get_block_state(state)
 
         vae = components.vae
-        if block_state.output_type == "latent":
-            block_state.images = block_state.latents
-        else:
-            latents = block_state.latents
-            latent_ids = block_state.latent_ids
+        latents = block_state.latents
 
-            latents = self._unpack_latents_with_ids(latents, latent_ids)
+        latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+        latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
+            latents.device, latents.dtype
+        )
+        latents = latents * latents_bn_std + latents_bn_mean
 
-            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
-            latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
-                latents.device, latents.dtype
-            )
-            latents = latents * latents_bn_std + latents_bn_mean
+        latents = self._unpatchify_latents(latents)
 
-            latents = self._unpatchify_latents(latents)
-
-            block_state.images = vae.decode(latents, return_dict=False)[0]
-            block_state.images = components.image_processor.postprocess(
-                block_state.images, output_type=block_state.output_type
-            )
+        block_state.images = vae.decode(latents, return_dict=False)[0]
+        block_state.images = components.image_processor.postprocess(
+            block_state.images, output_type=block_state.output_type
+        )
 
         self.set_block_state(state, block_state)
         return components, state
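Review note: for context, here is a rough sketch (not the library code) of what the new unpack/decode split amounts to. `Flux2UnpackLatentsStep` scatters the packed token sequence back onto a 2D latent grid using the per-token position IDs, and `Flux2DecodeStep` then inverts the VAE's encode-time batch-norm normalization before unpatchifying and decoding. The helper names and the assumed `(t, h, w, l)` id layout below are illustrative, not the actual implementation:

```python
import torch


def unpack_latents_with_ids(latents: torch.Tensor, latent_ids: torch.Tensor) -> torch.Tensor:
    # latents: (B, L, C) packed token sequence; latent_ids: (B, L, 4) with
    # assumed (t, h, w, l) coordinates per token (h/w at indices 1 and 2).
    batch_size, _, channels = latents.shape
    h_ids, w_ids = latent_ids[..., 1].long(), latent_ids[..., 2].long()
    height, width = int(h_ids.max()) + 1, int(w_ids.max()) + 1
    grid = latents.new_zeros(batch_size, channels, height, width)
    for b in range(batch_size):
        # place each token's channel vector at its (h, w) cell
        grid[b, :, h_ids[b], w_ids[b]] = latents[b].transpose(0, 1)
    return grid


def denormalize_with_bn_stats(latents, running_mean, running_var, eps):
    # invert the encode-time batch norm: x = x_hat * sqrt(var + eps) + mean
    std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps).to(latents.device, latents.dtype)
    mean = running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    return latents * std + mean
```

Splitting the unpack out of the decode step is what lets the Klein presets reuse the same decode block while inserting `after_denoise` independently.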
diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py
index 3dd5661d49..a30382b5f7 100644
--- a/src/diffusers/modular_pipelines/flux2/denoise.py
+++ b/src/diffusers/modular_pipelines/flux2/denoise.py
@@ -306,10 +306,6 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
                 type_hint=torch.Tensor,
                 description="4D position IDs for latent tokens (T, H, W, L)",
             ),
-            InputParam(
-                kwargs_type="denoiser_input_fields",
-                description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
-            ),
         ]
 
     @torch.no_grad()
@@ -339,20 +335,12 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
             ),
         }
 
-        transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
-        additional_cond_kwargs = {}
-        for field_name, field_value in block_state.denoiser_input_fields.items():
-            if field_name in transformer_args and field_name not in guider_inputs:
-                additional_cond_kwargs[field_name] = field_value
-        block_state.additional_cond_kwargs.update(additional_cond_kwargs)
-
         components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
         guider_state = components.guider.prepare_inputs(guider_inputs)
 
         for guider_state_batch in guider_state:
             components.guider.prepare_models(components.transformer)
             cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
-            cond_kwargs.update(additional_cond_kwargs)
 
             noise_pred = components.transformer(
                 hidden_states=latent_model_input,
@@ -458,8 +446,6 @@ class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
             len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
         )
 
-        block_state.additional_cond_kwargs = {}
-
         with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
             for i, t in enumerate(block_state.timesteps):
                 components, block_state = self.loop_step(components, block_state, i=i, t=t)
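Review note: with the `denoiser_input_fields` plumbing removed, the guider is the only consumer of conditional inputs in this loop. For reviewers unfamiliar with the guider flow, the per-step effect is the standard classifier-free guidance combination; a minimal sketch of that combination (not the `ClassifierFreeGuidance` internals) for reference:

```python
import torch


def cfg_combine(
    noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    # Push the conditional prediction away from the unconditional baseline;
    # guidance_scale = 1.0 reduces to the plain conditional prediction.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```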
diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py
index 835feb86cc..b2a93e0a25 100644
--- a/src/diffusers/modular_pipelines/flux2/encoders.py
+++ b/src/diffusers/modular_pipelines/flux2/encoders.py
@@ -17,6 +17,9 @@ from typing import List, Optional, Tuple, Union
 import torch
 from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
 
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+
 from ...models import AutoencoderKLFlux2
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -259,6 +262,141 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
         ComponentSpec("tokenizer", Qwen2TokenizerFast),
     ]
 
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(name="is_distilled", default=True),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("prompt"),
+            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
+            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "prompt_embeds",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="Text embeddings from Qwen3 used to guide the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        prompt = block_state.prompt
+
+        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    @staticmethod
+    # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
+    def _get_qwen3_prompt_embeds(
+        text_encoder: Qwen3ForCausalLM,
+        tokenizer: Qwen2TokenizerFast,
+        prompt: Union[str, List[str]],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        max_sequence_length: int = 512,
+        hidden_states_layers: List[int] = (9, 18, 27),
+    ):
+        dtype = text_encoder.dtype if dtype is None else dtype
+        device = text_encoder.device if device is None else device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        all_input_ids = []
+        all_attention_masks = []
+
+        for single_prompt in prompt:
+            messages = [{"role": "user", "content": single_prompt}]
+            text = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=max_sequence_length,
+            )
+
+            all_input_ids.append(inputs["input_ids"])
+            all_attention_masks.append(inputs["attention_mask"])
+
+        input_ids = torch.cat(all_input_ids, dim=0).to(device)
+        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+        # Forward pass through the model
+        output = text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            use_cache=False,
+        )
+
+        # Only use outputs from intermediate layers and stack them
+        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+        out = out.to(dtype=dtype, device=device)
+
+        batch_size, num_channels, seq_len, hidden_dim = out.shape
+        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+        return prompt_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        device = components._execution_device
+
+        prompt = block_state.prompt
+        if prompt is None:
+            prompt = ""
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
+            text_encoder=components.text_encoder,
+            tokenizer=components.tokenizer,
+            prompt=prompt,
+            device=device,
+            max_sequence_length=block_state.max_sequence_length,
+            hidden_states_layers=block_state.text_encoder_out_layers,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen3ForCausalLM),
+            ComponentSpec("tokenizer", Qwen2TokenizerFast),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 4.0}),
+                default_creation_method="from_config",
+            ),
+        ]
+
     @property
     def expected_configs(self) -> List[ConfigSpec]:
         return [
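Review note: `_get_qwen3_prompt_embeds` builds per-token features by concatenating three intermediate hidden layers along the channel dimension rather than using only the last hidden state. A toy-shape walk-through of that stacking (dimensions are made up; the real ones depend on the Qwen3 checkpoint):

```python
import torch

batch_size, seq_len, hidden_dim = 2, 8, 16
# stand-in for output.hidden_states: embeddings plus 30 transformer layers
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(31)]

layers = (9, 18, 27)
out = torch.stack([hidden_states[k] for k in layers], dim=1)  # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, len(layers) * hidden_dim)
assert prompt_embeds.shape == (batch_size, seq_len, 3 * hidden_dim)  # (B, L, 3*D)
```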
diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
index a31673b6e7..bad167f842 100644
--- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
+++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py
@@ -21,7 +21,7 @@ from .before_denoise import (
     Flux2RoPEInputsStep,
     Flux2SetTimestepsStep,
 )
-from .decoders import Flux2DecodeStep
+from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
 from .denoise import Flux2DenoiseStep
 from .encoders import (
     Flux2RemoteTextEncoderStep,
@@ -99,6 +99,7 @@ AUTO_BLOCKS = InsertableDict(
         ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
         ("before_denoise", Flux2BeforeDenoiseStep()),
         ("denoise", Flux2DenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
         ("decode", Flux2DecodeStep()),
     ]
 )
@@ -111,6 +112,7 @@ REMOTE_AUTO_BLOCKS = InsertableDict(
         ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
         ("before_denoise", Flux2BeforeDenoiseStep()),
         ("denoise", Flux2DenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
         ("decode", Flux2DecodeStep()),
     ]
 )
@@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
         ("set_timesteps", Flux2SetTimestepsStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
         ("decode", Flux2DecodeStep()),
     ]
 )
@@ -154,6 +157,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
         ("set_timesteps", Flux2SetTimestepsStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
         ("decode", Flux2DecodeStep()),
     ]
 )
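Review note: with `after_denoise` registered in every preset, the assembled pipelines pick up the unpack step automatically between `denoise` and `decode`. A quick way to check the ordering locally, assuming the usual modular-pipelines entry points (`SequentialPipelineBlocks.from_blocks_dict` and the `sub_blocks` mapping):

```python
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux2.modular_blocks_flux2 import TEXT2IMAGE_BLOCKS

blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
# expect: ..., "denoise", "after_denoise", "decode"
print(list(blocks.sub_blocks.keys()))
```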
("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -111,6 +112,7 @@ REMOTE_AUTO_BLOCKS = InsertableDict( ("vae_image_encoder", Flux2AutoVaeEncoderStep()), ("before_denoise", Flux2BeforeDenoiseStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict( ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -154,6 +157,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict( ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index fc787ad1d2..22949c99d7 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -21,10 +21,11 @@ from .before_denoise import ( Flux2RoPEInputsStep, Flux2SetTimestepsStep, ) -from .decoders import Flux2DecodeStep +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep from .encoders import ( Flux2KleinTextEncoderStep, + Flux2KleinBaseTextEncoderStep, Flux2VaeEncoderStep, ) from .inputs import ( @@ -35,7 +36,9 @@ from .inputs import ( logger = logging.get_logger(__name__) # pylint: disable=invalid-name - +### +### VAE encoder +### Flux2KleinVaeEncoderBlocks = InsertableDict( [ ("preprocess", Flux2ProcessImagesInputStep()), @@ -69,6 +72,9 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): " - If `image` is not provided, step will be skipped." 
         )
 
+###
+### Core denoise
+###
 
 Flux2KleinCoreDenoiseBlocks = InsertableDict(
     [
@@ -78,6 +84,7 @@ Flux2KleinCoreDenoiseBlocks = InsertableDict(
         ("set_timesteps", Flux2SetTimestepsStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2KleinDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
     ]
 )
 
@@ -99,6 +106,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
         " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
         " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
         " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
+        " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
     )
 
 
@@ -110,6 +118,7 @@ Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
         ("set_timesteps", Flux2SetTimestepsStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2KleinBaseDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
     ]
 )
 
@@ -130,9 +139,12 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
         " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
         " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
         " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
+        " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
     )
 
-
+###
+### Auto blocks
+###
 class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
     model_name = "flux2-klein"
     block_classes = [
@@ -155,7 +167,7 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
 class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
     model_name = "flux2-klein"
     block_classes = [
-        Flux2KleinTextEncoderStep(),
+        Flux2KleinBaseTextEncoderStep(),
         Flux2KleinAutoVaeEncoderStep(),
        Flux2KleinAutoVaeEncoderStep() if False else Flux2KleinAutoVaeEncoderStep(),