
Merge branch 'main' into apply-lora-scale-decorator

Sayak Paul
2026-01-28 09:30:36 +05:30
committed by GitHub
33 changed files with 2099 additions and 273 deletions

View File

@@ -413,6 +413,9 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,

View File

@@ -143,41 +143,86 @@ class GlmImageAdaLayerNormZero(nn.Module):
class GlmImageLayerKVCache:
"""KV cache for GlmImage model."""
"""KV cache for GlmImage model.
Supports per-sample caching for batch processing where each sample may have different condition images.
"""
def __init__(self):
self.k_cache = None
self.v_cache = None
self.k_caches: List[Optional[torch.Tensor]] = []
self.v_caches: List[Optional[torch.Tensor]] = []
self.mode: Optional[str] = None # "write", "read", "skip"
self.current_sample_idx: int = 0 # Current sample index for writing
def store(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache is None:
self.k_cache = k
self.v_cache = v
"""Store KV cache for the current sample."""
# k, v shape: (1, seq_len, num_heads, head_dim)
if len(self.k_caches) <= self.current_sample_idx:
# First time storing for this sample
self.k_caches.append(k)
self.v_caches.append(v)
else:
self.k_cache = torch.cat([self.k_cache, k], dim=1)
self.v_cache = torch.cat([self.v_cache, v], dim=1)
# Append to existing cache for this sample (multiple condition images)
self.k_caches[self.current_sample_idx] = torch.cat([self.k_caches[self.current_sample_idx], k], dim=1)
self.v_caches[self.current_sample_idx] = torch.cat([self.v_caches[self.current_sample_idx], v], dim=1)
def get(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache.shape[0] != k.shape[0]:
k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
else:
k_cache_expanded = self.k_cache
v_cache_expanded = self.v_cache
"""Get combined KV cache for all samples in the batch.
k_cache = torch.cat([k_cache_expanded, k], dim=1)
v_cache = torch.cat([v_cache_expanded, v], dim=1)
return k_cache, v_cache
Args:
k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
Returns:
Combined key and value tensors with cached values prepended.
"""
batch_size = k.shape[0]
num_cached_samples = len(self.k_caches)
if num_cached_samples == 0:
return k, v
if num_cached_samples == 1:
# Single cache, expand for all batch samples (shared condition images)
k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
elif num_cached_samples == batch_size:
# Per-sample cache, concatenate along batch dimension
k_cache_expanded = torch.cat(self.k_caches, dim=0)
v_cache_expanded = torch.cat(self.v_caches, dim=0)
else:
# Mismatch: try to handle by repeating the caches
# This handles cases like num_images_per_prompt > 1
repeat_factor = batch_size // num_cached_samples
if batch_size % num_cached_samples == 0:
k_cache_list = []
v_cache_list = []
for i in range(num_cached_samples):
k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
k_cache_expanded = torch.cat(k_cache_list, dim=0)
v_cache_expanded = torch.cat(v_cache_list, dim=0)
else:
raise ValueError(
f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
f"Batch size must be a multiple of the number of cached samples."
)
k_combined = torch.cat([k_cache_expanded, k], dim=1)
v_combined = torch.cat([v_cache_expanded, v], dim=1)
return k_combined, v_combined
def clear(self):
self.k_cache = None
self.v_cache = None
self.k_caches = []
self.v_caches = []
self.mode = None
self.current_sample_idx = 0
def next_sample(self):
"""Move to the next sample for writing."""
self.current_sample_idx += 1
class GlmImageKVCache:
"""Container for all layers' KV caches."""
"""Container for all layers' KV caches.
Supports per-sample caching for batch processing where each sample may have different condition images.
"""
def __init__(self, num_layers: int):
self.num_layers = num_layers
@@ -192,6 +237,12 @@ class GlmImageKVCache:
for cache in self.caches:
cache.mode = mode
def next_sample(self):
"""Move to the next sample for writing. Call this after processing
all condition images for one batch sample."""
for cache in self.caches:
cache.next_sample()
def clear(self):
for cache in self.caches:
cache.clear()
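
A minimal standalone sketch of the per-sample KV cache behaviour described in the docstrings above (illustration only, not the GlmImage classes themselves): one cache slot is kept per batch sample, grown along the sequence dimension for additional condition images, and either broadcast or concatenated at read time to match the batch size.

import torch

class PerSampleKVCache:  # illustrative stand-in, not diffusers' GlmImageLayerKVCache
    def __init__(self):
        self.k_caches, self.v_caches = [], []
        self.current_sample_idx = 0

    def store(self, k, v):
        if len(self.k_caches) <= self.current_sample_idx:
            # first condition image for this sample
            self.k_caches.append(k)
            self.v_caches.append(v)
        else:
            # additional condition image for the same sample: grow along seq_len
            i = self.current_sample_idx
            self.k_caches[i] = torch.cat([self.k_caches[i], k], dim=1)
            self.v_caches[i] = torch.cat([self.v_caches[i], v], dim=1)

    def next_sample(self):
        self.current_sample_idx += 1

    def get(self, k, v):
        batch_size = k.shape[0]
        if not self.k_caches:
            return k, v
        if len(self.k_caches) == 1:
            # shared conditioning: broadcast the single cache over the batch
            k_c = self.k_caches[0].expand(batch_size, -1, -1, -1)
            v_c = self.v_caches[0].expand(batch_size, -1, -1, -1)
        else:
            # one cache per sample: stack along the batch dimension
            k_c = torch.cat(self.k_caches, dim=0)
            v_c = torch.cat(self.v_caches, dim=0)
        return torch.cat([k_c, k], dim=1), torch.cat([v_c, v], dim=1)

cache = PerSampleKVCache()
cache.store(torch.randn(1, 4, 8, 64), torch.randn(1, 4, 8, 64))  # sample 0, one condition image
cache.next_sample()
cache.store(torch.randn(1, 4, 8, 64), torch.randn(1, 4, 8, 64))  # sample 1, one condition image
k_all, v_all = cache.get(torch.randn(2, 16, 8, 64), torch.randn(2, 16, 8, 64))
print(k_all.shape)  # torch.Size([2, 20, 8, 64]) -- cached keys prepended to the current keys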

View File

@@ -54,7 +54,10 @@ else:
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -43,7 +43,7 @@ else:
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
_import_structure["modular_blocks_flux2"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
@@ -51,10 +51,11 @@ else:
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2CoreDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
from .modular_blocks_flux2 import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
@@ -93,10 +94,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2CoreDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
scheduler = components.scheduler
@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
@@ -353,7 +339,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
]
@property
@@ -365,12 +350,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@staticmethod
@@ -403,6 +382,72 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
return components, state
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="negative_prompt_embeds", required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="negative_txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
block_state.negative_txt_ids = None
if block_state.negative_prompt_embeds is not None:
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
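
A quick hedged illustration of what `_prepare_text_ids` above produces: with the default `t_coord`, `torch.cartesian_prod` over singleton T/H/W ranges and the token range yields an (L, 4) grid whose first three coordinates stay at 0 while the last one enumerates token positions; stacking one grid per prompt gives `txt_ids` of shape (batch_size, L, 4).

import torch

L_tokens = 4
coords = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(L_tokens))
print(coords)
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 0, 0, 2],
#         [0, 0, 0, 3]])
# txt_ids = torch.stack([coords] * batch_size) -> shape (batch_size, L_tokens, 4)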
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@@ -506,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the guidance scale tensor for Flux2 inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("guidance_scale", default=4.0),
InputParam("num_images_per_prompt", default=1),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
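
For reference, the guidance preparation above boils down to broadcasting a single scalar to the effective batch size (a sketch with illustrative numbers):

import torch

guidance_scale, batch_size, num_images_per_prompt = 4.0, 2, 2
guidance = torch.full([1], guidance_scale, dtype=torch.float32)
guidance = guidance.expand(batch_size * num_images_per_prompt)
print(guidance)  # tensor([4., 4., 4., 4.])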

View File

@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
class Flux2UnpackLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
return "Step that unpacks the latents from the denoising step"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
"latents",
type_hint=torch.Tensor,
description="The denoise latents from denoising step, unpacked with position IDs.",
)
]
@@ -107,6 +94,62 @@ class Flux2DecodeStep(ModularPipelineBlocks):
return torch.stack(x_list, dim=0)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
block_state.latents = latents
self.set_block_state(state, block_state)
return components, state
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
@@ -121,26 +164,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = block_state.latents
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
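
A hedged sketch of the batch-norm denormalization step above: before decoding, the latents are un-normalized with the VAE's running BatchNorm statistics (the tensors below are stand-ins for `vae.bn.running_mean`, `vae.bn.running_var`, and `vae.config.batch_norm_eps`).

import torch

latents = torch.randn(1, 32, 64, 64)   # unpacked latents, channels first
running_mean = torch.zeros(32)          # stand-in for vae.bn.running_mean
running_var = torch.ones(32)            # stand-in for vae.bn.running_var
batch_norm_eps = 1e-5                   # stand-in for vae.config.batch_norm_eps

mean = running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
std = torch.sqrt(running_var.view(1, -1, 1, 1) + batch_norm_eps).to(latents.device, latents.dtype)
latents = latents * std + mean          # denormalized latents, ready for vae.decode(...)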

View File

@@ -16,6 +16,8 @@ from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
@@ -25,8 +27,8 @@ from ..modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
if is_torch_xla_available():
@@ -134,6 +136,229 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
return components, block_state
# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
# support CFG for Flux2-Klein base model
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("transformer", Flux2Transformer2DModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"negative_prompt_embeds",
required=False,
type_hint=torch.Tensor,
description="Negative text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"negative_txt_ids",
required=False,
type_hint=torch.Tensor,
description="4D position IDs for negative text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"txt_ids": (
getattr(block_state, "txt_ids", None),
getattr(block_state, "negative_txt_ids", None),
),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
components.guider.cleanup_models(components.transformer)
# perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@@ -250,3 +475,35 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinBaseLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
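
As a rough guide to what the `guider` component does with the per-batch predictions collected above, here is the standard classifier-free guidance combination (a sketch; the actual `ClassifierFreeGuidance` implementation may differ in details such as rescaling):

import torch

def cfg_combine(noise_pred_cond, noise_pred_uncond, guidance_scale=4.0):
    # standard CFG: push the prediction away from the unconditional branch
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

cond = torch.randn(1, 4096, 64)    # prediction using prompt_embeds / txt_ids
uncond = torch.randn(1, 4096, 64)  # prediction using negative_prompt_embeds / negative_txt_ids
noise_pred = cfg_combine(cond, uncond, guidance_scale=4.0)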

View File

@@ -15,13 +15,15 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,10 +81,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
@@ -99,14 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -165,10 +158,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -205,7 +194,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
@@ -222,15 +210,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -244,10 +225,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -270,6 +247,289 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
return components, state
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=True),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
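
A hedged shape walkthrough of the `_get_qwen3_prompt_embeds` stacking above: three intermediate hidden states are stacked and folded into the channel dimension, so the prompt embedding width is three times the encoder hidden size (numbers below are illustrative only).

import torch

batch_size, seq_len, hidden_dim = 2, 512, 2048
# stand-ins for output.hidden_states[k] at layers (9, 18, 27)
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(3)]

out = torch.stack(hidden_states, dim=1)                                   # (B, 3, seq_len, hidden_dim)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)  # (B, seq_len, 3 * hidden_dim)
print(prompt_embeds.shape)  # torch.Size([2, 512, 6144])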
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Negative text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
if components.requires_unconditional_embeds:
negative_prompt = [""] * len(prompt)
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
]
@@ -89,6 +89,90 @@ class Flux2TextInputStep(ModularPipelineBlocks):
return components, state
class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
required=False,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Negative text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
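
For clarity, a small sketch of the repeat-and-view pattern used above to duplicate each prompt's embeddings `num_images_per_prompt` times while keeping per-prompt ordering (illustrative shapes):

import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 8, 16, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

out = prompt_embeds.repeat(1, num_images_per_prompt, 1)          # (2, 24, 16)
out = out.view(batch_size * num_images_per_prompt, seq_len, -1)  # (6, 8, 16)
# rows 0-2 are copies of prompt 0, rows 3-5 are copies of prompt 1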
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2PrepareGuidanceStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
@@ -41,7 +47,6 @@ Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
@@ -72,33 +77,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
)
Flux2BeforeDenoiseBlocks = InsertableDict(
Flux2CoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
return (
"Core denoise step that performs the denoising process for Flux2-dev.\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -107,10 +135,8 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -130,6 +156,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
TEXT2IMAGE_BLOCKS = InsertableDict(
[
@@ -137,8 +173,10 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -152,8 +190,10 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
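
As a conceptual aside (plain Python, not the diffusers API): the InsertableDicts above only fix an ordered mapping of step names to blocks, and a sequential wrapper such as `Flux2CoreDenoiseStep` runs them one after another over a shared state, which is why the ordering of entries matters.

def run_sequential(named_blocks, state):
    # toy stand-in for SequentialPipelineBlocks: each block consumes and
    # returns the shared pipeline state in the declared order
    for name, block in named_blocks:
        state = block(state)
    return state

steps = [
    ("prepare_latents", lambda s: {**s, "latents": "noise"}),
    ("denoise", lambda s: {**s, "latents": "denoised"}),
    ("decode", lambda s: {**s, "images": f"decoded({s['latents']})"}),
]
print(run_sequential(steps, {}))  # {'latents': 'denoised', 'images': 'decoded(denoised)'}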

View File

@@ -0,0 +1,232 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2KleinBaseRoPEInputsStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinBaseTextEncoderStep,
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2KleinBaseTextInputStep,
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
################
# VAE encoder
################
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
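
A hedged sketch of the trigger behaviour described above (conceptual only, not the AutoPipelineBlocks implementation): the VAE encoder sub-block runs only when its trigger input `image` is present, otherwise the step is a no-op.

def auto_vae_encoder_step(state, vae_encoder_block):
    # block_trigger_inputs = ["image"]: run the encoder only when an image is provided
    if state.get("image") is not None:
        return vae_encoder_block(state)
    return state  # text-to-image path: the step is skipped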
###
### Core denoise
###
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return (
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
###
### Auto blocks
###
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinBaseTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -55,3 +57,56 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
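
Worked numbers for the defaults above, assuming the common 4-level Flux2 VAE (so `vae_scale_factor` resolves to 8): `default_sample_size * vae_scale_factor = 128 * 8 = 1024`, i.e. the pipeline defaults to 1024x1024 generation.

default_sample_size = 128
vae_scale_factor = 2 ** (4 - 1)   # assuming len(vae.config.block_out_channels) == 4
print(default_sample_size * vae_scale_factor)  # 1024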

View File

@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

View File

@@ -52,6 +52,15 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_NEGATIVE_PROMPT = (
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
@@ -359,7 +368,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -549,6 +558,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
num_latent_conditional_frames: int = 2,
):
r"""
The call function to the pipeline for generation. Supports three modes:
@@ -614,6 +624,10 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
num_latent_conditional_frames (`int`, defaults to `2`):
Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames
extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1
for Image2World-like behavior (single frame conditioning).
Examples:
@@ -692,19 +706,38 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
if num_latent_conditional_frames not in [1, 2]:
raise ValueError(
f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}"
)
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1
total_input_frames = len(video)
if total_input_frames < frames_to_extract:
raise ValueError(
f"Input video has only {total_input_frames} frames but Video2World requires at least "
f"{frames_to_extract} frames for conditioning."
)
num_frames_in = frames_to_extract
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
# For Video2World: extract last frames_to_extract frames from input, then pad
if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]:
video = video[:, :, -num_frames_in:, :, :]
num_frames_out = num_frames
if video.shape[2] < num_frames_out:
n_pad_frames = num_frames_out - num_frames_in
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
n_pad_frames = num_frames_out - video.shape[2]
last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W]
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
video = torch.cat((video, pad_frames), dim=2)
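A self-contained sketch of the Video2World frame arithmetic above; the frame counts follow the docstring, while the tensor shapes and target length are assumed for illustration:
import torch

num_latent_conditional_frames = 2                                    # must be 1 or 2
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1      # -> 5 pixel frames

video = torch.zeros(1, 3, frames_to_extract, 64, 64)                 # [B, C, T, H, W], dummy input
num_frames_out = 93                                                   # assumed target frame count
n_pad_frames = num_frames_out - video.shape[2]
pad_frames = video[:, :, -1:, :, :].repeat(1, 1, n_pad_frames, 1, 1)  # repeat the last frame
video = torch.cat((video, pad_frames), dim=2)                         # -> [1, 3, 93, 64, 64]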

View File

@@ -49,6 +49,14 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_NEGATIVE_PROMPT = (
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
)
EXAMPLE_DOC_STRING = """
Examples:
@@ -300,7 +308,7 @@ class Cosmos2TextToImagePipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):

View File

@@ -50,6 +50,14 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_NEGATIVE_PROMPT = (
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
)
EXAMPLE_DOC_STRING = """
Examples:
@@ -319,7 +327,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):

View File

@@ -49,6 +49,14 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_NEGATIVE_PROMPT = (
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
)
EXAMPLE_DOC_STRING = """
Examples:
@@ -285,7 +293,7 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):

View File

@@ -50,6 +50,14 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_NEGATIVE_PROMPT = (
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
)
EXAMPLE_DOC_STRING = """
Examples:
@@ -331,7 +339,7 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):

View File

@@ -260,25 +260,115 @@ class GlmImagePipeline(DiffusionPipeline):
token_ids = token_ids.view(1, -1)
return token_ids
@staticmethod
def _validate_and_normalize_images(
image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
batch_size: int,
) -> Optional[List[List[PIL.Image.Image]]]:
"""
Validate and normalize image inputs to List[List[PIL.Image]].
Rules:
- batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length
- batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
- Other formats raise ValueError
Args:
image: Input images in various formats
batch_size: Number of prompts in the batch
Returns:
Normalized images as List[List[PIL.Image]], or None if no images provided
"""
if image is None or len(image) == 0:
return None
first_element = image[0]
if batch_size == 1:
# Legacy format: List[PIL.Image] -> [[img1, img2, ...]]
if not isinstance(first_element, (list, tuple)):
return [list(image)]
# Already in List[List[PIL.Image]] format
if len(image) != 1:
raise ValueError(
f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}."
)
return [list(image[0])]
# batch_size > 1: must be List[List[PIL.Image]]
if not isinstance(first_element, (list, tuple)):
raise ValueError(
f"For batch_size > 1, images must be List[List[PIL.Image]] format. "
f"Got List[{type(first_element).__name__}] instead. "
f"Each prompt requires its own list of condition images."
)
if len(image) != batch_size:
raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).")
# Validate homogeneous: all sublists must have same length
num_input_images_per_prompt = len(image[0])
for idx, imgs in enumerate(image):
if len(imgs) != num_input_images_per_prompt:
raise ValueError(
f"All prompts must have the same number of condition images. "
f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images."
)
return [list(imgs) for imgs in image]
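A usage sketch of the normalization rules documented above, assuming `GlmImagePipeline` is importable from `diffusers` (the import path is an assumption) and using dummy PIL images:
import PIL.Image
from diffusers import GlmImagePipeline  # assumed import path

img_a = PIL.Image.new("RGB", (64, 64))
img_b = PIL.Image.new("RGB", (64, 64))

# batch_size == 1: legacy flat list is wrapped into a single sublist
assert GlmImagePipeline._validate_and_normalize_images([img_a, img_b], 1) == [[img_a, img_b]]

# batch_size == 2: one sublist of condition images per prompt, all the same length
assert GlmImagePipeline._validate_and_normalize_images([[img_a], [img_b]], 2) == [[img_a], [img_b]]

# batch_size == 2 with a flat list raises ValueError: each prompt needs its own list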
def generate_prior_tokens(
self,
prompt: str,
prompt: Union[str, List[str]],
height: int,
width: int,
image: Optional[List[PIL.Image.Image]] = None,
image: Optional[List[List[PIL.Image.Image]]] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
):
"""
Generate prior tokens for the DiT model using the AR model.
Args:
prompt: Single prompt or list of prompts
height: Target image height
width: Target image width
image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated
using _validate_and_normalize_images() before calling this method.
device: Target device
generator: Random generator for reproducibility
Returns:
Tuple of:
- prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
- prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains
the upsampled prior token ids for all condition images in that sample. None for t2i.
- source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape
(num_condition_images, 3) with upsampled grid info. None for t2i.
"""
device = device or self._execution_device
is_text_to_image = image is None or len(image) == 0
content = []
if image is not None:
for img in image:
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": prompt})
messages = [{"role": "user", "content": content}]
# Normalize prompt to list format
prompt_list = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt_list)
# Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]]
is_text_to_image = image is None
# Build messages for each sample in the batch
all_messages = []
for idx, p in enumerate(prompt_list):
content = []
if not is_text_to_image:
for img in image[idx]:
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": p})
all_messages.append([{"role": "user", "content": content}])
# Process with the processor (supports batch with left padding)
inputs = self.processor.apply_chat_template(
messages,
all_messages,
tokenize=True,
padding=True if batch_size > 1 else False,
target_h=height,
target_w=width,
return_dict=True,
@@ -286,44 +376,117 @@ class GlmImagePipeline(DiffusionPipeline):
).to(device)
image_grid_thw = inputs.get("image_grid_thw")
images_per_sample = inputs.get("images_per_sample")
# Determine number of condition images and grids per sample
num_condition_images = 0 if is_text_to_image else len(image[0])
if images_per_sample is not None:
num_grids_per_sample = images_per_sample[0].item()
else:
# Fallback for batch_size=1: total grids is for single sample
num_grids_per_sample = image_grid_thw.shape[0]
# Compute generation params (same for all samples in homogeneous batch)
first_sample_grids = image_grid_thw[:num_grids_per_sample]
max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image
)
# Generate source image tokens (prior_token_image_ids) for i2i mode
prior_token_image_ids = None
if image is not None:
prior_token_image_embed = self.vision_language_encoder.get_image_features(
inputs["pixel_values"], image_grid_thw[:-1]
)
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
prior_token_image_embed, image_grid_thw[:-1]
)
source_image_grid_thw = None
if not is_text_to_image:
# Extract source grids by selecting condition image indices (skip target grids)
# Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...]
# We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...]
source_indices = []
for sample_idx in range(batch_size):
base = sample_idx * num_grids_per_sample
source_indices.extend(range(base, base + num_condition_images))
source_grids = image_grid_thw[source_indices]
# For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
# max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
if len(source_grids) > 0:
prior_token_image_embed = self.vision_language_encoder.get_image_features(
inputs["pixel_values"], source_grids, return_dict=False
)
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
prior_token_image_embed, source_grids
)
# Upsample each source image's prior tokens to match VAE/DiT resolution
split_sizes = source_grids.prod(dim=-1).tolist()
prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes)
upsampled_prior_ids = []
for i, prior_ids in enumerate(prior_ids_per_source):
t, h, w = source_grids[i].tolist()
upsampled = self._upsample_token_ids(prior_ids, int(h), int(w))
upsampled_prior_ids.append(upsampled.squeeze(0))
prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0)
# Upsample grid dimensions for later splitting
upsampled_grids = source_grids.clone()
upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2
upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2
source_image_grid_thw = upsampled_grids
# Generate with AR model
# Set torch random seed from generator for reproducibility
# (transformers generate() doesn't accept generator parameter)
if generator is not None:
seed = generator.initial_seed()
torch.manual_seed(seed)
if device is not None and device.type == "cuda":
torch.cuda.manual_seed(seed)
outputs = self.vision_language_encoder.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
)
prior_token_ids_d32 = self._extract_large_image_tokens(
outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
)
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
# Extract and upsample prior tokens for each sample
# For left-padded inputs, generated tokens start after the padded input sequence
all_prior_token_ids = []
max_input_length = inputs["input_ids"].shape[-1]
for idx in range(batch_size):
# For left-padded sequences, generated tokens start at max_input_length
# (padding is on the left, so all sequences end at the same position)
prior_token_ids_d32 = self._extract_large_image_tokens(
outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w
)
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
all_prior_token_ids.append(prior_token_ids)
prior_token_ids = torch.cat(all_prior_token_ids, dim=0)
return prior_token_ids, prior_token_image_ids
# Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption
prior_token_image_ids_per_sample = None
source_image_grid_thw_per_sample = None
if prior_token_image_ids is not None and source_image_grid_thw is not None:
# Split grids: each sample has num_condition_images grids
source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))
# Split prior_token_image_ids: tokens per sample may vary due to different image sizes
tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()
tokens_per_sample = []
for i in range(batch_size):
start_idx = i * num_condition_images
end_idx = start_idx + num_condition_images
tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx]))
prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample))
return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample
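A small sketch of the per-image token split used at the end of this method; the upsampled grids are made up, but the `prod`/`split` bookkeeping mirrors the code above:
import torch

source_image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 6]])       # illustrative (t, h, w) grids
tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()     # [16, 12]

prior_token_image_ids = torch.arange(sum(tokens_per_image))        # flat token stream
per_image = torch.split(prior_token_image_ids, tokens_per_image)
print([t.shape[0] for t in per_image])                             # [16, 12]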
def get_glyph_texts(self, prompt):
prompt = prompt[0] if isinstance(prompt, list) else prompt
ocr_texts = (
re.findall(r"'([^']*)'", prompt)
+ re.findall(r"“([^“”]*)”", prompt)
+ re.findall(r'"([^"]*)"', prompt)
+ re.findall(r"「([^「」]*)」", prompt)
)
return ocr_texts
"""Extract glyph texts from prompt(s). Returns a list of lists for batch processing."""
if isinstance(prompt, str):
prompt = [prompt]
all_ocr_texts = []
for p in prompt:
ocr_texts = (
re.findall(r"'([^']*)'", p)
+ re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
+ re.findall(r'"([^"]*)"', p)
+ re.findall(r"「([^「」]*)」", p)
)
all_ocr_texts.append(ocr_texts)
return all_ocr_texts
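A quick illustration of the rewritten quote extraction; the prompts are made up:
import re

prompts = ['A poster that says "GRAND OPENING"', "A mug with 'hello' and 「你好」 on it"]

all_ocr_texts = []
for p in prompts:
    ocr_texts = (
        re.findall(r"'([^']*)'", p)
        + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
        + re.findall(r'"([^"]*)"', p)
        + re.findall(r"「([^「」]*)」", p)
    )
    all_ocr_texts.append(ocr_texts)

print(all_ocr_texts)  # [['GRAND OPENING'], ['hello', '你好']]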
def _get_glyph_embeds(
self,
@@ -332,29 +495,51 @@ class GlmImagePipeline(DiffusionPipeline):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""Get glyph embeddings for each prompt in the batch."""
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(
glyph_texts if len(glyph_texts) > 0 else [""],
max_length=max_sequence_length,
truncation=True,
).input_ids
input_ids = [
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
]
max_length = max(len(input_ids_) for input_ids_ in input_ids)
attention_mask = torch.tensor(
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
)
input_ids = torch.tensor(
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
# get_glyph_texts now returns a list of lists (one per prompt)
all_glyph_texts = self.get_glyph_texts(prompt)
all_glyph_embeds = []
for glyph_texts in all_glyph_texts:
if len(glyph_texts) == 0:
glyph_texts = [""]
input_ids = self.tokenizer(
glyph_texts,
max_length=max_sequence_length,
truncation=True,
).input_ids
input_ids = [
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
]
max_length = max(len(input_ids_) for input_ids_ in input_ids)
attention_mask = torch.tensor(
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
input_ids = torch.tensor(
[
input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_))
for input_ids_ in input_ids
],
device=device,
)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
all_glyph_embeds.append(glyph_embeds)
# Pad to same sequence length and stack (use left padding to match transformers)
max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
padded_embeds = []
for emb in all_glyph_embeds:
if emb.size(1) < max_seq_len:
pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype)
emb = torch.cat([pad, emb], dim=1) # left padding
padded_embeds.append(emb)
glyph_embeds = torch.cat(padded_embeds, dim=0)
return glyph_embeds.to(device=device, dtype=dtype)
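A minimal sketch of the left-padding step that aligns per-prompt glyph embeddings to a common sequence length; the shapes are illustrative:
import torch

all_glyph_embeds = [torch.randn(1, 3, 8), torch.randn(1, 5, 8)]     # per-prompt (1, seq, hidden)
max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)

padded_embeds = []
for emb in all_glyph_embeds:
    if emb.size(1) < max_seq_len:
        pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), dtype=emb.dtype)
        emb = torch.cat([pad, emb], dim=1)                           # pad on the left
    padded_embeds.append(emb)

glyph_embeds = torch.cat(padded_embeds, dim=0)
print(glyph_embeds.shape)                                            # torch.Size([2, 5, 8])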
def encode_prompt(
@@ -399,9 +584,9 @@ class GlmImagePipeline(DiffusionPipeline):
if prompt_embeds is None:
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# Repeat embeddings for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
# For GLM-Image, negative_prompt must be "" instead of None
if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -409,9 +594,8 @@ class GlmImagePipeline(DiffusionPipeline):
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
seq_len = negative_prompt_embeds.size(1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if num_images_per_prompt > 1:
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return prompt_embeds, negative_prompt_embeds
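The switch to `repeat_interleave` keeps all copies of a given prompt adjacent in the batch (prompt 0's images first, then prompt 1's); a tiny illustration with made-up embeddings:
import torch

prompt_embeds = torch.tensor([[[1.0]], [[2.0]]])                     # (batch=2, seq=1, dim=1)
num_images_per_prompt = 2

out = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
print(out.flatten().tolist())  # [1.0, 1.0, 2.0, 2.0] -- copies of each prompt stay adjacent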
@@ -442,7 +626,9 @@ class GlmImagePipeline(DiffusionPipeline):
prompt_embeds=None,
negative_prompt_embeds=None,
prior_token_ids=None,
prior_image_token_ids=None,
prior_token_image_ids=None,
source_image_grid_thw=None,
image=None,
):
if (
height is not None
@@ -488,12 +674,24 @@ class GlmImagePipeline(DiffusionPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if (prior_token_ids is None and prior_image_token_ids is not None) or (
prior_token_ids is not None and prior_image_token_ids is None
):
# Validate prior token inputs: for i2i mode, all three must be provided together
# For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None)
prior_image_inputs = [prior_token_image_ids, source_image_grid_thw]
num_prior_image_inputs = sum(x is not None for x in prior_image_inputs)
if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs):
raise ValueError(
f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
"`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. "
f"Got prior_token_image_ids={prior_token_image_ids is not None}, "
f"source_image_grid_thw={source_image_grid_thw is not None}."
)
if num_prior_image_inputs > 0 and prior_token_ids is None:
raise ValueError(
"`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided."
)
if num_prior_image_inputs > 0 and image is None:
raise ValueError(
"`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided "
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
)
if prior_token_ids is not None and prompt_embeds is None:
@@ -545,7 +743,8 @@ class GlmImagePipeline(DiffusionPipeline):
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prior_token_ids: Optional[torch.FloatTensor] = None,
prior_image_token_ids: Optional[torch.Tensor] = None,
prior_token_image_ids: Optional[List[torch.Tensor]] = None,
source_image_grid_thw: Optional[List[torch.Tensor]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
@@ -598,7 +797,9 @@ class GlmImagePipeline(DiffusionPipeline):
prompt_embeds,
negative_prompt_embeds,
prior_token_ids,
prior_image_token_ids,
prior_token_image_ids,
source_image_grid_thw,
image,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
@@ -611,34 +812,47 @@ class GlmImagePipeline(DiffusionPipeline):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if batch_size != 1:
raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
device = self._execution_device
# 2. Preprocess image tokens and prompt tokens
if prior_token_ids is None:
prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
prompt=prompt[0] if isinstance(prompt, list) else prompt,
image=image,
height=height,
width=width,
device=device,
)
# 2. Validate and normalize image format
normalized_image = self._validate_and_normalize_images(image, batch_size)
# 3. Preprocess image
if image is not None:
preprocessed_condition_images = []
for img in image:
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
image_height = (image_height // multiple_of) * multiple_of
image_width = (image_width // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
preprocessed_condition_images.append(img)
height = height or image_height
width = width or image_width
image = preprocessed_condition_images
# 3. Generate prior tokens (batch mode)
# Get a single generator for AR model (use first if list provided)
ar_generator = generator[0] if isinstance(generator, list) else generator
if prior_token_ids is None:
prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = (
self.generate_prior_tokens(
prompt=prompt,
image=normalized_image,
height=height,
width=width,
device=device,
generator=ar_generator,
)
)
else:
# User provided prior_token_ids directly (from generate_prior_tokens)
prior_token_image_ids_per_sample = prior_token_image_ids
source_image_grid_thw_per_sample = source_image_grid_thw
# 4. Preprocess images for VAE encoding
preprocessed_images = None
if normalized_image is not None:
preprocessed_images = []
for prompt_images in normalized_image:
prompt_preprocessed = []
for img in prompt_images:
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
image_height = (image_height // multiple_of) * multiple_of
image_width = (image_width // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
prompt_preprocessed.append(img)
height = height or image_height
width = width or image_width
preprocessed_images.append(prompt_preprocessed)
# 5. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -652,7 +866,7 @@ class GlmImagePipeline(DiffusionPipeline):
dtype=self.dtype,
)
# 4. Prepare latents and (optional) image kv cache
# 6. Prepare latents and (optional) image kv cache
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
@@ -666,7 +880,7 @@ class GlmImagePipeline(DiffusionPipeline):
)
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
if image is not None:
if normalized_image is not None:
kv_caches.set_mode("write")
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
@@ -674,29 +888,38 @@ class GlmImagePipeline(DiffusionPipeline):
latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
condition_latent = retrieve_latents(
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
)
condition_latent = (condition_latent - latents_mean) / latents_std
# Process each sample's condition images
for prompt_idx in range(batch_size):
prompt_images = preprocessed_images[prompt_idx]
prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx]
prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx]
# Do not remove.
# It is used to run the reference image through a
# forward pass at timestep 0 and keep the KV cache.
_ = self.transformer(
hidden_states=condition_latent,
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
kv_caches=kv_caches,
)
# Split this sample's prior_token_image_ids by each image's token count
split_sizes = prompt_grid_thw.prod(dim=-1).tolist()
prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes)
# Process each condition image for this sample
for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image):
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
condition_latent = retrieve_latents(
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
)
condition_latent = (condition_latent - latents_mean) / latents_std
# 6. Prepare additional timestep conditions
_ = self.transformer(
hidden_states=condition_latent,
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
kv_caches=kv_caches,
)
# Move to next sample's cache slot
kv_caches.next_sample()
# 7. Prepare additional timestep conditions
target_size = (height, width)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
@@ -726,10 +949,13 @@ class GlmImagePipeline(DiffusionPipeline):
)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
# 8. Denoising loop
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# Repeat prior_token_ids for num_images_per_prompt
if num_images_per_prompt > 1:
prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0)
prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -742,7 +968,7 @@ class GlmImagePipeline(DiffusionPipeline):
timestep = t.expand(latents.shape[0]) - 1
if image is not None:
if prior_token_image_ids_per_sample is not None:
kv_caches.set_mode("read")
noise_pred_cond = self.transformer(
@@ -760,7 +986,7 @@ class GlmImagePipeline(DiffusionPipeline):
# perform guidance
if self.do_classifier_free_guidance:
if image is not None:
if prior_token_image_ids_per_sample is not None:
kv_caches.set_mode("skip")
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,

View File

@@ -262,6 +262,9 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
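The added guard drops an all-ones mask so downstream attention takes the same unmasked path as `None` (the compile test further below checks exactly this); a tiny sketch of the condition:
import torch

prompt_embeds_mask = torch.ones(1, 8, dtype=torch.long)    # no padding tokens
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
    prompt_embeds_mask = None                               # treat as "no mask"
print(prompt_embeds_mask)                                   # None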
def check_inputs(

View File

@@ -324,6 +324,9 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(

View File

@@ -305,6 +305,9 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(

View File

@@ -309,6 +309,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(

View File

@@ -321,6 +321,9 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs

View File

@@ -323,6 +323,9 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs

View File

@@ -305,6 +305,9 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(

View File

@@ -316,6 +316,9 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(

View File

@@ -328,6 +328,9 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):

View File

@@ -17,6 +17,51 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -276,3 +276,74 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_no_mask_2 = model(**inputs_no_mask)
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_all_ones_2 = model(**inputs_all_ones)
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
output_with_padding_2 = model(**inputs_with_padding)
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))

View File

@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2KleinAutoBlocks,
Flux2KleinModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return

View File

@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return

View File

@@ -169,7 +169,7 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# fmt: off
expected_slice = np.array(
[
0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113
]
)
# fmt: on
@@ -177,20 +177,109 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(image.shape, (3, 32, 32))
self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))
@unittest.skip("Not supported.")
def test_inference_batch_single_identical(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
"""Test that batch=1 produces consistent results with the same seed."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
@unittest.skip("Not supported.")
def test_inference_batch_consistent(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
# Run twice with same seed
inputs1 = self.get_dummy_inputs(device, seed=42)
inputs2 = self.get_dummy_inputs(device, seed=42)
image1 = pipe(**inputs1).images[0]
image2 = pipe(**inputs2).images[0]
self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
def test_inference_batch_multiple_prompts(self):
"""Test batch processing with multiple prompts."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32
inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}
images = pipe(**inputs).images
# Should return 2 images
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))
@unittest.skip("Not supported.")
def test_num_images_per_prompt(self):
# GLM-Image has batch_size=1 constraint due to AR model
pass
"""Test generating multiple images per prompt."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32
inputs = {
"prompt": "A photo of a cat",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}
images = pipe(**inputs).images
# Should return 2 images for single prompt
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))
def test_batch_with_num_images_per_prompt(self):
"""Test batch prompts with num_images_per_prompt > 1."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(42)
height, width = 32, 32
inputs = {
"prompt": ["A photo of a cat", "A photo of a dog"],
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
"num_images_per_prompt": 2,
}
images = pipe(**inputs).images
# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)
@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):