1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

a few fix: unpack latents before decoder etc

This commit is contained in:
yiyixuxu
2026-01-20 23:13:53 +01:00
parent d2953678ad
commit c10041e57e
5 changed files with 233 additions and 54 deletions

View File

@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
class Flux2UnpackLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
return "Step that unpacks the latents from the denoising step"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
"latents",
type_hint=torch.Tensor,
description="The denoise latents from denoising step, unpacked with position IDs.",
)
]
@@ -107,6 +94,64 @@ class Flux2DecodeStep(ModularPipelineBlocks):
return torch.stack(x_list, dim=0)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
block_state.latents = latents
self.set_block_state(state, block_state)
return components, state
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
@@ -121,26 +166,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = block_state.latents
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -306,10 +306,6 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
]
@torch.no_grad()
@@ -339,20 +335,12 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
additional_cond_kwargs = {}
for field_name, field_value in block_state.denoiser_input_fields.items():
if field_name in transformer_args and field_name not in guider_inputs:
additional_cond_kwargs[field_name] = field_value
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs.update(additional_cond_kwargs)
noise_pred = components.transformer(
hidden_states=latent_model_input,
@@ -458,8 +446,6 @@ class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
block_state.additional_cond_kwargs = {}
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)

View File

@@ -17,6 +17,9 @@ from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from ...guiders import ClassifierFreeGuidance
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -259,6 +262,141 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=True),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [

View File

@@ -21,7 +21,7 @@ from .before_denoise import (
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
@@ -99,6 +99,7 @@ AUTO_BLOCKS = InsertableDict(
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -111,6 +112,7 @@ REMOTE_AUTO_BLOCKS = InsertableDict(
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -154,6 +157,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)

View File

@@ -21,10 +21,11 @@ from .before_denoise import (
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinTextEncoderStep,
Flux2KleinBaseTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
@@ -35,7 +36,9 @@ from .inputs import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
###
### VAE encoder
###
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
@@ -69,6 +72,9 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
" - If `image` is not provided, step will be skipped."
)
###
### Core denoise
###
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
@@ -78,6 +84,7 @@ Flux2KleinCoreDenoiseBlocks = InsertableDict(
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
@@ -99,6 +106,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@@ -110,6 +118,7 @@ Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
@@ -130,9 +139,12 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
###
### Auto blocks
###
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
@@ -155,7 +167,7 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinBaseTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),