mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into apply-lora-scale-decorator
This commit is contained in:
@@ -413,6 +413,9 @@ else:
|
||||
_import_structure["modular_pipelines"].extend(
|
||||
[
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2ModularPipeline",
|
||||
"FluxAutoBlocks",
|
||||
"FluxKontextAutoBlocks",
|
||||
@@ -1146,6 +1149,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .modular_pipelines import (
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
FluxAutoBlocks,
|
||||
FluxKontextAutoBlocks,
|
||||
|
||||
@@ -143,41 +143,86 @@ class GlmImageAdaLayerNormZero(nn.Module):
|
||||
|
||||
|
||||
class GlmImageLayerKVCache:
|
||||
"""KV cache for GlmImage model."""
|
||||
"""KV cache for GlmImage model.
|
||||
Supports per-sample caching for batch processing where each sample may have different condition images.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.k_caches: List[Optional[torch.Tensor]] = []
|
||||
self.v_caches: List[Optional[torch.Tensor]] = []
|
||||
self.mode: Optional[str] = None # "write", "read", "skip"
|
||||
self.current_sample_idx: int = 0 # Current sample index for writing
|
||||
|
||||
def store(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache is None:
|
||||
self.k_cache = k
|
||||
self.v_cache = v
|
||||
"""Store KV cache for the current sample."""
|
||||
# k, v shape: (1, seq_len, num_heads, head_dim)
|
||||
if len(self.k_caches) <= self.current_sample_idx:
|
||||
# First time storing for this sample
|
||||
self.k_caches.append(k)
|
||||
self.v_caches.append(v)
|
||||
else:
|
||||
self.k_cache = torch.cat([self.k_cache, k], dim=1)
|
||||
self.v_cache = torch.cat([self.v_cache, v], dim=1)
|
||||
# Append to existing cache for this sample (multiple condition images)
|
||||
self.k_caches[self.current_sample_idx] = torch.cat([self.k_caches[self.current_sample_idx], k], dim=1)
|
||||
self.v_caches[self.current_sample_idx] = torch.cat([self.v_caches[self.current_sample_idx], v], dim=1)
|
||||
|
||||
def get(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache.shape[0] != k.shape[0]:
|
||||
k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
|
||||
v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
|
||||
else:
|
||||
k_cache_expanded = self.k_cache
|
||||
v_cache_expanded = self.v_cache
|
||||
"""Get combined KV cache for all samples in the batch.
|
||||
|
||||
k_cache = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_cache = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_cache, v_cache
|
||||
Args:
|
||||
k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
|
||||
v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
|
||||
Returns:
|
||||
Combined key and value tensors with cached values prepended.
|
||||
"""
|
||||
batch_size = k.shape[0]
|
||||
num_cached_samples = len(self.k_caches)
|
||||
if num_cached_samples == 0:
|
||||
return k, v
|
||||
if num_cached_samples == 1:
|
||||
# Single cache, expand for all batch samples (shared condition images)
|
||||
k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
|
||||
v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
|
||||
elif num_cached_samples == batch_size:
|
||||
# Per-sample cache, concatenate along batch dimension
|
||||
k_cache_expanded = torch.cat(self.k_caches, dim=0)
|
||||
v_cache_expanded = torch.cat(self.v_caches, dim=0)
|
||||
else:
|
||||
# Mismatch: try to handle by repeating the caches
|
||||
# This handles cases like num_images_per_prompt > 1
|
||||
repeat_factor = batch_size // num_cached_samples
|
||||
if batch_size % num_cached_samples == 0:
|
||||
k_cache_list = []
|
||||
v_cache_list = []
|
||||
for i in range(num_cached_samples):
|
||||
k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
|
||||
v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
|
||||
k_cache_expanded = torch.cat(k_cache_list, dim=0)
|
||||
v_cache_expanded = torch.cat(v_cache_list, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
|
||||
f"Batch size must be a multiple of the number of cached samples."
|
||||
)
|
||||
|
||||
k_combined = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_combined = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_combined, v_combined
|
||||
|
||||
def clear(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.k_caches = []
|
||||
self.v_caches = []
|
||||
self.mode = None
|
||||
self.current_sample_idx = 0
|
||||
|
||||
def next_sample(self):
|
||||
"""Move to the next sample for writing."""
|
||||
self.current_sample_idx += 1
|
||||
|
||||
|
||||
class GlmImageKVCache:
|
||||
"""Container for all layers' KV caches."""
|
||||
"""Container for all layers' KV caches.
|
||||
Supports per-sample caching for batch processing where each sample may have different condition images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_layers: int):
|
||||
self.num_layers = num_layers
|
||||
@@ -192,6 +237,12 @@ class GlmImageKVCache:
|
||||
for cache in self.caches:
|
||||
cache.mode = mode
|
||||
|
||||
def next_sample(self):
|
||||
"""Move to the next sample for writing. Call this after processing
|
||||
all condition images for one batch sample."""
|
||||
for cache in self.caches:
|
||||
cache.next_sample()
|
||||
|
||||
def clear(self):
|
||||
for cache in self.caches:
|
||||
cache.clear()
|
||||
|
||||
@@ -54,7 +54,10 @@ else:
|
||||
]
|
||||
_import_structure["flux2"] = [
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .components_manager import ComponentsManager
|
||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
|
||||
from .flux2 import (
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
)
|
||||
from .modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
BlockState,
|
||||
|
||||
@@ -43,7 +43,7 @@ else:
|
||||
"Flux2ProcessImagesInputStep",
|
||||
"Flux2TextInputStep",
|
||||
]
|
||||
_import_structure["modular_blocks"] = [
|
||||
_import_structure["modular_blocks_flux2"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"REMOTE_AUTO_BLOCKS",
|
||||
@@ -51,10 +51,11 @@ else:
|
||||
"IMAGE_CONDITIONED_BLOCKS",
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2AutoVaeEncoderStep",
|
||||
"Flux2BeforeDenoiseStep",
|
||||
"Flux2CoreDenoiseStep",
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
from .modular_blocks_flux2 import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
IMAGE_CONDITIONED_BLOCKS,
|
||||
@@ -93,10 +94,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
Flux2AutoBlocks,
|
||||
Flux2AutoVaeEncoderStep,
|
||||
Flux2BeforeDenoiseStep,
|
||||
Flux2CoreDenoiseStep,
|
||||
Flux2VaeEncoderSequentialStep,
|
||||
)
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
from .modular_blocks_flux2_klein import (
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=4.0),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
|
||||
@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps,
|
||||
block_state.device,
|
||||
device,
|
||||
timesteps=timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
block_state.guidance = guidance
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
@@ -353,7 +339,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt_embeds", required=True),
|
||||
InputParam(name="latent_ids"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -365,12 +350,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="latent_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -403,6 +382,72 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt_embeds", required=True),
|
||||
InputParam(name="negative_prompt_embeds", required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
|
||||
"""Prepare 4D position IDs for text tokens."""
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
seq_l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, seq_l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
device = prompt_embeds.device
|
||||
|
||||
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
|
||||
block_state.txt_ids = block_state.txt_ids.to(device)
|
||||
|
||||
block_state.negative_txt_ids = None
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
|
||||
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@@ -506,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the guidance scale tensor for Flux2 inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("guidance_scale", default=4.0),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
class Flux2UnpackLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
Flux2ImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||
return "Step that unpacks the latents from the denoising step"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
"latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoise latents from denoising step, unpacked with position IDs.",
|
||||
)
|
||||
]
|
||||
|
||||
@@ -107,6 +94,62 @@ class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
latent_ids = block_state.latent_ids
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
block_state.latents = latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
Flux2ImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _unpatchify_latents(latents):
|
||||
"""Convert patchified latents back to regular format."""
|
||||
@@ -121,26 +164,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
vae = components.vae
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.images = block_state.latents
|
||||
else:
|
||||
latents = block_state.latents
|
||||
latent_ids = block_state.latent_ids
|
||||
latents = block_state.latents
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
|
||||
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
|
||||
latents = self._unpatchify_latents(latents)
|
||||
|
||||
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -16,6 +16,8 @@ from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
@@ -25,8 +27,8 @@ from ..modular_pipeline import (
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -134,6 +136,229 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# same as Flux2LoopDenoiser but guidance=None
|
||||
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
latents = block_state.latents
|
||||
latent_model_input = latents.to(components.transformer.dtype)
|
||||
img_ids = block_state.latent_ids
|
||||
|
||||
image_latents = getattr(block_state, "image_latents", None)
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||
image_latent_ids = block_state.image_latent_ids
|
||||
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = components.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
txt_ids=block_state.txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
block_state.noise_pred = noise_pred
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
# support CFG for Flux2-Klein base model
|
||||
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", Flux2Transformer2DModel),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"negative_prompt_embeds",
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="Negative text embeddings from Qwen3",
|
||||
),
|
||||
InputParam(
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"negative_txt_ids",
|
||||
required=False,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for negative text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
latents = block_state.latents
|
||||
latent_model_input = latents.to(components.transformer.dtype)
|
||||
img_ids = block_state.latent_ids
|
||||
|
||||
image_latents = getattr(block_state, "image_latents", None)
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||
image_latent_ids = block_state.image_latent_ids
|
||||
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
"txt_ids": (
|
||||
getattr(block_state, "txt_ids", None),
|
||||
getattr(block_state, "negative_txt_ids", None),
|
||||
),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
|
||||
noise_pred = components.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
# perform guidance
|
||||
block_state.noise_pred = components.guider(guider_state)[0]
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@@ -250,3 +475,35 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
|
||||
block_names = ["denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoises the latents for Flux2. \n"
|
||||
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `Flux2KleinLoopDenoiser`\n"
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
|
||||
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
|
||||
block_names = ["denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoises the latents for Flux2. \n"
|
||||
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `Flux2KleinBaseLoopDenoiser`\n"
|
||||
" - `Flux2LoopAfterDenoiser`\n"
|
||||
"This block supports both text-to-image and image-conditioned generation."
|
||||
)
|
||||
|
||||
@@ -15,13 +15,15 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import AutoencoderKLFlux2
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -79,10 +81,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||
InputParam("joint_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -99,14 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -165,10 +158,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -205,7 +194,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -222,15 +210,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -244,10 +225,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.prompt_embeds is not None:
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
@@ -270,6 +247,289 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen3ForCausalLM),
|
||||
ComponentSpec("tokenizer", Qwen2TokenizerFast),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: List[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
hidden_states_layers=block_state.text_encoder_out_layers,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen3ForCausalLM),
|
||||
ComponentSpec("tokenizer", Qwen2TokenizerFast),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(name="is_distilled", default=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Negative text embeddings from qwen3 used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: List[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
prompt = block_state.prompt
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
hidden_states_layers=block_state.text_encoder_out_layers,
|
||||
)
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = [""] * len(prompt)
|
||||
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
hidden_states_layers=block_state.text_encoder_out_layers,
|
||||
)
|
||||
else:
|
||||
block_state.negative_prompt_embeds = None
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2VaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
|
||||
required=True,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -89,6 +89,90 @@ class Flux2TextInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
InputParam(
|
||||
"negative_prompt_embeds",
|
||||
required=False,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
_, seq_len, _ = block_state.negative_prompt_embeds.shape
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
|
||||
1, block_state.num_images_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
|
||||
@@ -12,16 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
Flux2PrepareGuidanceStep,
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2DenoiseStep
|
||||
from .encoders import (
|
||||
Flux2RemoteTextEncoderStep,
|
||||
@@ -41,7 +47,6 @@ Flux2VaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", Flux2ProcessImagesInputStep()),
|
||||
("encode", Flux2VaeEncoderStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -72,33 +77,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
Flux2BeforeDenoiseBlocks = InsertableDict(
|
||||
Flux2CoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_guidance", Flux2PrepareGuidanceStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2BeforeDenoiseBlocks.values()
|
||||
block_names = Flux2BeforeDenoiseBlocks.keys()
|
||||
block_classes = Flux2CoreDenoiseBlocks.values()
|
||||
block_names = Flux2CoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-dev.\n"
|
||||
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
|
||||
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
|
||||
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2TextEncoderStep()),
|
||||
("text_input", Flux2TextInputStep()),
|
||||
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("before_denoise", Flux2BeforeDenoiseStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("vae_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("denoise", Flux2CoreDenoiseStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -107,10 +135,8 @@ AUTO_BLOCKS = InsertableDict(
|
||||
REMOTE_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2RemoteTextEncoderStep()),
|
||||
("text_input", Flux2TextInputStep()),
|
||||
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("before_denoise", Flux2BeforeDenoiseStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("vae_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("denoise", Flux2CoreDenoiseStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -130,6 +156,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
|
||||
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
@@ -137,8 +173,10 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
("text_input", Flux2TextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_guidance", Flux2PrepareGuidanceStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -152,8 +190,10 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_guidance", Flux2PrepareGuidanceStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
@@ -0,0 +1,232 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
Flux2KleinBaseRoPEInputsStep,
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
|
||||
from .encoders import (
|
||||
Flux2KleinBaseTextEncoderStep,
|
||||
Flux2KleinTextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2KleinBaseTextInputStep,
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
################
|
||||
# VAE encoder
|
||||
################
|
||||
|
||||
Flux2KleinVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", Flux2ProcessImagesInputStep()),
|
||||
("encode", Flux2VaeEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2KleinVaeEncoderBlocks.values()
|
||||
block_names = Flux2KleinVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
|
||||
|
||||
|
||||
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [Flux2KleinVaeEncoderSequentialStep]
|
||||
block_names = ["img_conditioning"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"VAE encoder step that encodes the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block that works for image conditioning tasks.\n"
|
||||
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
|
||||
" - If `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
###
|
||||
### Core denoise
|
||||
###
|
||||
|
||||
Flux2KleinCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2KleinDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = Flux2KleinCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
|
||||
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2KleinBaseTextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
|
||||
("denoise", Flux2KleinBaseDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
|
||||
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
###
|
||||
### Auto blocks
|
||||
###
|
||||
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinBaseTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinBaseCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
@@ -55,3 +57,56 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
if getattr(self, "transformer", None):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux2-Klein.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinAutoBlocks"
|
||||
else:
|
||||
return "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if getattr(self, "vae", None) is not None:
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 32
|
||||
if getattr(self, "transformer", None):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
|
||||
return False
|
||||
|
||||
requires_unconditional_embeds = False
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("flux", "FluxModularPipeline"),
|
||||
("flux-kontext", "FluxKontextModularPipeline"),
|
||||
("flux2", "Flux2ModularPipeline"),
|
||||
("flux2-klein", "Flux2KleinModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
|
||||
@@ -52,6 +52,15 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
@@ -359,7 +368,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
@@ -549,6 +558,7 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
num_latent_conditional_frames: int = 2,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
@@ -614,6 +624,10 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
num_latent_conditional_frames (`int`, defaults to `2`):
|
||||
Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames
|
||||
extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1
|
||||
for Image2World-like behavior (single frame conditioning).
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -692,19 +706,38 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
if num_latent_conditional_frames not in [1, 2]:
|
||||
raise ValueError(
|
||||
f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}"
|
||||
)
|
||||
|
||||
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1
|
||||
|
||||
total_input_frames = len(video)
|
||||
|
||||
if total_input_frames < frames_to_extract:
|
||||
raise ValueError(
|
||||
f"Input video has only {total_input_frames} frames but Video2World requires at least "
|
||||
f"{frames_to_extract} frames for conditioning."
|
||||
)
|
||||
|
||||
num_frames_in = frames_to_extract
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
# For Video2World: extract last frames_to_extract frames from input, then pad
|
||||
if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]:
|
||||
video = video[:, :, -num_frames_in:, :, :]
|
||||
|
||||
num_frames_out = num_frames
|
||||
|
||||
if video.shape[2] < num_frames_out:
|
||||
n_pad_frames = num_frames_out - num_frames_in
|
||||
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
|
||||
n_pad_frames = num_frames_out - video.shape[2]
|
||||
last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W]
|
||||
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
|
||||
video = torch.cat((video, pad_frames), dim=2)
|
||||
|
||||
|
||||
@@ -49,6 +49,14 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -300,7 +308,7 @@ class Cosmos2TextToImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -50,6 +50,14 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -319,7 +327,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -49,6 +49,14 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -285,7 +293,7 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -50,6 +50,14 @@ else:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
|
||||
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
"Overall, the video is of poor quality."
|
||||
)
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
@@ -331,7 +339,7 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
||||
@@ -260,25 +260,115 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
token_ids = token_ids.view(1, -1)
|
||||
return token_ids
|
||||
|
||||
@staticmethod
|
||||
def _validate_and_normalize_images(
|
||||
image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
|
||||
batch_size: int,
|
||||
) -> Optional[List[List[PIL.Image.Image]]]:
|
||||
"""
|
||||
Validate and normalize image inputs to List[List[PIL.Image]].
|
||||
|
||||
Rules:
|
||||
- batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length
|
||||
- batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
|
||||
- Other formats raise ValueError
|
||||
|
||||
Args:
|
||||
image: Input images in various formats
|
||||
batch_size: Number of prompts in the batch
|
||||
|
||||
Returns:
|
||||
Normalized images as List[List[PIL.Image]], or None if no images provided
|
||||
"""
|
||||
if image is None or len(image) == 0:
|
||||
return None
|
||||
|
||||
first_element = image[0]
|
||||
|
||||
if batch_size == 1:
|
||||
# Legacy format: List[PIL.Image] -> [[img1, img2, ...]]
|
||||
if not isinstance(first_element, (list, tuple)):
|
||||
return [list(image)]
|
||||
# Already in List[List[PIL.Image]] format
|
||||
if len(image) != 1:
|
||||
raise ValueError(
|
||||
f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}."
|
||||
)
|
||||
return [list(image[0])]
|
||||
|
||||
# batch_size > 1: must be List[List[PIL.Image]]
|
||||
if not isinstance(first_element, (list, tuple)):
|
||||
raise ValueError(
|
||||
f"For batch_size > 1, images must be List[List[PIL.Image]] format. "
|
||||
f"Got List[{type(first_element).__name__}] instead. "
|
||||
f"Each prompt requires its own list of condition images."
|
||||
)
|
||||
|
||||
if len(image) != batch_size:
|
||||
raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).")
|
||||
|
||||
# Validate homogeneous: all sublists must have same length
|
||||
num_input_images_per_prompt = len(image[0])
|
||||
for idx, imgs in enumerate(image):
|
||||
if len(imgs) != num_input_images_per_prompt:
|
||||
raise ValueError(
|
||||
f"All prompts must have the same number of condition images. "
|
||||
f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images."
|
||||
)
|
||||
|
||||
return [list(imgs) for imgs in image]
|
||||
|
||||
def generate_prior_tokens(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int,
|
||||
width: int,
|
||||
image: Optional[List[PIL.Image.Image]] = None,
|
||||
image: Optional[List[List[PIL.Image.Image]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
):
|
||||
"""
|
||||
Generate prior tokens for the DiT model using the AR model.
|
||||
|
||||
Args:
|
||||
prompt: Single prompt or list of prompts
|
||||
height: Target image height
|
||||
width: Target image width
|
||||
image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated
|
||||
using _validate_and_normalize_images() before calling this method.
|
||||
device: Target device
|
||||
generator: Random generator for reproducibility
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
|
||||
- prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains
|
||||
the upsampled prior token ids for all condition images in that sample. None for t2i.
|
||||
- source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape
|
||||
(num_condition_images, 3) with upsampled grid info. None for t2i.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
is_text_to_image = image is None or len(image) == 0
|
||||
content = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
# Normalize prompt to list format
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt_list)
|
||||
|
||||
# Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]]
|
||||
is_text_to_image = image is None
|
||||
# Build messages for each sample in the batch
|
||||
all_messages = []
|
||||
for idx, p in enumerate(prompt_list):
|
||||
content = []
|
||||
if not is_text_to_image:
|
||||
for img in image[idx]:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": p})
|
||||
all_messages.append([{"role": "user", "content": content}])
|
||||
# Process with the processor (supports batch with left padding)
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
all_messages,
|
||||
tokenize=True,
|
||||
padding=True if batch_size > 1 else False,
|
||||
target_h=height,
|
||||
target_w=width,
|
||||
return_dict=True,
|
||||
@@ -286,44 +376,117 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
).to(device)
|
||||
|
||||
image_grid_thw = inputs.get("image_grid_thw")
|
||||
images_per_sample = inputs.get("images_per_sample")
|
||||
|
||||
# Determine number of condition images and grids per sample
|
||||
num_condition_images = 0 if is_text_to_image else len(image[0])
|
||||
if images_per_sample is not None:
|
||||
num_grids_per_sample = images_per_sample[0].item()
|
||||
else:
|
||||
# Fallback for batch_size=1: total grids is for single sample
|
||||
num_grids_per_sample = image_grid_thw.shape[0]
|
||||
|
||||
# Compute generation params (same for all samples in homogeneous batch)
|
||||
first_sample_grids = image_grid_thw[:num_grids_per_sample]
|
||||
max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
|
||||
image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
|
||||
image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image
|
||||
)
|
||||
|
||||
# Generate source image tokens (prior_token_image_ids) for i2i mode
|
||||
prior_token_image_ids = None
|
||||
if image is not None:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], image_grid_thw[:-1]
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, image_grid_thw[:-1]
|
||||
)
|
||||
source_image_grid_thw = None
|
||||
if not is_text_to_image:
|
||||
# Extract source grids by selecting condition image indices (skip target grids)
|
||||
# Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...]
|
||||
# We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...]
|
||||
source_indices = []
|
||||
for sample_idx in range(batch_size):
|
||||
base = sample_idx * num_grids_per_sample
|
||||
source_indices.extend(range(base, base + num_condition_images))
|
||||
source_grids = image_grid_thw[source_indices]
|
||||
|
||||
# For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
|
||||
# max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
|
||||
if len(source_grids) > 0:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], source_grids, return_dict=False
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, source_grids
|
||||
)
|
||||
# Upsample each source image's prior tokens to match VAE/DiT resolution
|
||||
split_sizes = source_grids.prod(dim=-1).tolist()
|
||||
prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes)
|
||||
upsampled_prior_ids = []
|
||||
for i, prior_ids in enumerate(prior_ids_per_source):
|
||||
t, h, w = source_grids[i].tolist()
|
||||
upsampled = self._upsample_token_ids(prior_ids, int(h), int(w))
|
||||
upsampled_prior_ids.append(upsampled.squeeze(0))
|
||||
prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0)
|
||||
# Upsample grid dimensions for later splitting
|
||||
upsampled_grids = source_grids.clone()
|
||||
upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2
|
||||
upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2
|
||||
source_image_grid_thw = upsampled_grids
|
||||
|
||||
# Generate with AR model
|
||||
# Set torch random seed from generator for reproducibility
|
||||
# (transformers generate() doesn't accept generator parameter)
|
||||
if generator is not None:
|
||||
seed = generator.initial_seed()
|
||||
torch.manual_seed(seed)
|
||||
if device is not None and device.type == "cuda":
|
||||
torch.cuda.manual_seed(seed)
|
||||
outputs = self.vision_language_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
# Extract and upsample prior tokens for each sample
|
||||
# For left-padded inputs, generated tokens start after the padded input sequence
|
||||
all_prior_token_ids = []
|
||||
max_input_length = inputs["input_ids"].shape[-1]
|
||||
for idx in range(batch_size):
|
||||
# For left-padded sequences, generated tokens start at max_input_length
|
||||
# (padding is on the left, so all sequences end at the same position)
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
all_prior_token_ids.append(prior_token_ids)
|
||||
prior_token_ids = torch.cat(all_prior_token_ids, dim=0)
|
||||
|
||||
return prior_token_ids, prior_token_image_ids
|
||||
# Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption
|
||||
prior_token_image_ids_per_sample = None
|
||||
source_image_grid_thw_per_sample = None
|
||||
if prior_token_image_ids is not None and source_image_grid_thw is not None:
|
||||
# Split grids: each sample has num_condition_images grids
|
||||
source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))
|
||||
# Split prior_token_image_ids: tokens per sample may vary due to different image sizes
|
||||
tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()
|
||||
tokens_per_sample = []
|
||||
for i in range(batch_size):
|
||||
start_idx = i * num_condition_images
|
||||
end_idx = start_idx + num_condition_images
|
||||
tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx]))
|
||||
prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample))
|
||||
|
||||
return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample
|
||||
|
||||
def get_glyph_texts(self, prompt):
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", prompt)
|
||||
+ re.findall(r"“([^“”]*)”", prompt)
|
||||
+ re.findall(r'"([^"]*)"', prompt)
|
||||
+ re.findall(r"「([^「」]*)」", prompt)
|
||||
)
|
||||
return ocr_texts
|
||||
"""Extract glyph texts from prompt(s). Returns a list of lists for batch processing."""
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
all_ocr_texts = []
|
||||
for p in prompt:
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", p)
|
||||
+ re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
|
||||
+ re.findall(r'"([^"]*)"', p)
|
||||
+ re.findall(r"「([^「」]*)」", p)
|
||||
)
|
||||
all_ocr_texts.append(ocr_texts)
|
||||
return all_ocr_texts
|
||||
|
||||
def _get_glyph_embeds(
|
||||
self,
|
||||
@@ -332,29 +495,51 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""Get glyph embeddings for each prompt in the batch."""
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
glyph_texts = self.get_glyph_texts(prompt)
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts if len(glyph_texts) > 0 else [""],
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
# get_glyph_texts now returns a list of lists (one per prompt)
|
||||
all_glyph_texts = self.get_glyph_texts(prompt)
|
||||
|
||||
all_glyph_embeds = []
|
||||
for glyph_texts in all_glyph_texts:
|
||||
if len(glyph_texts) == 0:
|
||||
glyph_texts = [""]
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_))
|
||||
for input_ids_ in input_ids
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
all_glyph_embeds.append(glyph_embeds)
|
||||
|
||||
# Pad to same sequence length and stack (use left padding to match transformers)
|
||||
max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
|
||||
padded_embeds = []
|
||||
for emb in all_glyph_embeds:
|
||||
if emb.size(1) < max_seq_len:
|
||||
pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype)
|
||||
emb = torch.cat([pad, emb], dim=1) # left padding
|
||||
padded_embeds.append(emb)
|
||||
|
||||
glyph_embeds = torch.cat(padded_embeds, dim=0)
|
||||
return glyph_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
def encode_prompt(
|
||||
@@ -399,9 +584,9 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
# Repeat embeddings for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# For GLM-Image, negative_prompt must be "" instead of None
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
@@ -409,9 +594,8 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
if num_images_per_prompt > 1:
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
@@ -442,7 +626,9 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prior_token_ids=None,
|
||||
prior_image_token_ids=None,
|
||||
prior_token_image_ids=None,
|
||||
source_image_grid_thw=None,
|
||||
image=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
@@ -488,12 +674,24 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if (prior_token_ids is None and prior_image_token_ids is not None) or (
|
||||
prior_token_ids is not None and prior_image_token_ids is None
|
||||
):
|
||||
# Validate prior token inputs: for i2i mode, all three must be provided together
|
||||
# For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None)
|
||||
prior_image_inputs = [prior_token_image_ids, source_image_grid_thw]
|
||||
num_prior_image_inputs = sum(x is not None for x in prior_image_inputs)
|
||||
if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs):
|
||||
raise ValueError(
|
||||
f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
|
||||
f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
|
||||
"`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. "
|
||||
f"Got prior_token_image_ids={prior_token_image_ids is not None}, "
|
||||
f"source_image_grid_thw={source_image_grid_thw is not None}."
|
||||
)
|
||||
if num_prior_image_inputs > 0 and prior_token_ids is None:
|
||||
raise ValueError(
|
||||
"`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided."
|
||||
)
|
||||
if num_prior_image_inputs > 0 and image is None:
|
||||
raise ValueError(
|
||||
"`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided "
|
||||
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
|
||||
)
|
||||
|
||||
if prior_token_ids is not None and prompt_embeds is None:
|
||||
@@ -545,7 +743,8 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prior_token_ids: Optional[torch.FloatTensor] = None,
|
||||
prior_image_token_ids: Optional[torch.Tensor] = None,
|
||||
prior_token_image_ids: Optional[List[torch.Tensor]] = None,
|
||||
source_image_grid_thw: Optional[List[torch.Tensor]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -598,7 +797,9 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prior_token_ids,
|
||||
prior_image_token_ids,
|
||||
prior_token_image_ids,
|
||||
source_image_grid_thw,
|
||||
image,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
@@ -611,34 +812,47 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Preprocess image tokens and prompt tokens
|
||||
if prior_token_ids is None:
|
||||
prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
|
||||
prompt=prompt[0] if isinstance(prompt, list) else prompt,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
)
|
||||
# 2. Validate and normalize image format
|
||||
normalized_image = self._validate_and_normalize_images(image, batch_size)
|
||||
|
||||
# 3. Preprocess image
|
||||
if image is not None:
|
||||
preprocessed_condition_images = []
|
||||
for img in image:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
preprocessed_condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
image = preprocessed_condition_images
|
||||
# 3. Generate prior tokens (batch mode)
|
||||
# Get a single generator for AR model (use first if list provided)
|
||||
ar_generator = generator[0] if isinstance(generator, list) else generator
|
||||
if prior_token_ids is None:
|
||||
prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = (
|
||||
self.generate_prior_tokens(
|
||||
prompt=prompt,
|
||||
image=normalized_image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
generator=ar_generator,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User provided prior_token_ids directly (from generate_prior_tokens)
|
||||
prior_token_image_ids_per_sample = prior_token_image_ids
|
||||
source_image_grid_thw_per_sample = source_image_grid_thw
|
||||
|
||||
# 4. Preprocess images for VAE encoding
|
||||
preprocessed_images = None
|
||||
if normalized_image is not None:
|
||||
preprocessed_images = []
|
||||
for prompt_images in normalized_image:
|
||||
prompt_preprocessed = []
|
||||
for img in prompt_images:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
prompt_preprocessed.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
preprocessed_images.append(prompt_preprocessed)
|
||||
|
||||
# 5. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
@@ -652,7 +866,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# 4. Prepare latents and (optional) image kv cache
|
||||
# 6. Prepare latents and (optional) image kv cache
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
@@ -666,7 +880,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
|
||||
|
||||
if image is not None:
|
||||
if normalized_image is not None:
|
||||
kv_caches.set_mode("write")
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
@@ -674,29 +888,38 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
# Process each sample's condition images
|
||||
for prompt_idx in range(batch_size):
|
||||
prompt_images = preprocessed_images[prompt_idx]
|
||||
prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx]
|
||||
prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx]
|
||||
|
||||
# Do not remove.
|
||||
# It would be use to run the reference image through a
|
||||
# forward pass at timestep 0 and keep the KV cache.
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
# Split this sample's prior_token_image_ids by each image's token count
|
||||
split_sizes = prompt_grid_thw.prod(dim=-1).tolist()
|
||||
prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes)
|
||||
# Process each condition image for this sample
|
||||
for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
|
||||
# 6. Prepare additional timestep conditions
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
# Move to next sample's cache slot
|
||||
kv_caches.next_sample()
|
||||
|
||||
# 7. Prepare additional timestep conditions
|
||||
target_size = (height, width)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
@@ -726,10 +949,13 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
# 8. Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# Repeat prior_token_ids for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
|
||||
prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -742,7 +968,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
|
||||
timestep = t.expand(latents.shape[0]) - 1
|
||||
|
||||
if image is not None:
|
||||
if prior_token_image_ids_per_sample is not None:
|
||||
kv_caches.set_mode("read")
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
@@ -760,7 +986,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if image is not None:
|
||||
if prior_token_image_ids_per_sample is not None:
|
||||
kv_caches.set_mode("skip")
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
|
||||
@@ -262,6 +262,9 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -324,6 +324,9 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -305,6 +305,9 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -309,6 +309,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -321,6 +321,9 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
|
||||
|
||||
@@ -323,6 +323,9 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
||||
|
||||
@@ -305,6 +305,9 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -316,6 +316,9 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
|
||||
@@ -328,6 +328,9 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
|
||||
|
||||
@@ -17,6 +17,51 @@ class Flux2AutoBlocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2ModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -276,3 +276,74 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
|
||||
# Test 1: Run with None mask (no padding, all tokens are valid)
|
||||
inputs_no_mask = inputs.copy()
|
||||
inputs_no_mask["encoder_hidden_states_mask"] = None
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_no_mask = model(**inputs_no_mask)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
# Keep the all-ones mask
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_all_ones = model(**inputs_all_ones)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
|
||||
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
|
||||
|
||||
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_with_padding = model(**inputs_with_padding)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import pytest
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||
"text_encoder_out_layers": (1,),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(9e-2)
|
||||
|
||||
|
||||
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||
"text_encoder_out_layers": (1,),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
}
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
|
||||
inputs["image"] = init_image
|
||||
|
||||
return inputs
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(9e-2)
|
||||
|
||||
@pytest.mark.skip(reason="batched inference is currently not supported")
|
||||
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
|
||||
return
|
||||
@@ -0,0 +1,91 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import pytest
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||
"text_encoder_out_layers": (1,),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(9e-2)
|
||||
|
||||
|
||||
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||
"text_encoder_out_layers": (1,),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
}
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
|
||||
inputs["image"] = init_image
|
||||
|
||||
return inputs
|
||||
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(9e-2)
|
||||
|
||||
@pytest.mark.skip(reason="batched inference is currently not supported")
|
||||
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
|
||||
return
|
||||
@@ -169,7 +169,7 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# fmt: off
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
|
||||
0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
@@ -177,20 +177,109 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(image.shape, (3, 32, 32))
|
||||
self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_inference_batch_single_identical(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
"""Test that batch=1 produces consistent results with the same seed."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_inference_batch_consistent(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
# Run twice with same seed
|
||||
inputs1 = self.get_dummy_inputs(device, seed=42)
|
||||
inputs2 = self.get_dummy_inputs(device, seed=42)
|
||||
|
||||
image1 = pipe(**inputs1).images[0]
|
||||
image2 = pipe(**inputs2).images[0]
|
||||
|
||||
self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
|
||||
|
||||
def test_inference_batch_multiple_prompts(self):
|
||||
"""Test batch processing with multiple prompts."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
height, width = 32, 32
|
||||
|
||||
inputs = {
|
||||
"prompt": ["A photo of a cat", "A photo of a dog"],
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.5,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
images = pipe(**inputs).images
|
||||
|
||||
# Should return 2 images
|
||||
self.assertEqual(len(images), 2)
|
||||
self.assertEqual(images[0].shape, (3, 32, 32))
|
||||
self.assertEqual(images[1].shape, (3, 32, 32))
|
||||
|
||||
@unittest.skip("Not supported.")
|
||||
def test_num_images_per_prompt(self):
|
||||
# GLM-Image has batch_size=1 constraint due to AR model
|
||||
pass
|
||||
"""Test generating multiple images per prompt."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
height, width = 32, 32
|
||||
|
||||
inputs = {
|
||||
"prompt": "A photo of a cat",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.5,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
"num_images_per_prompt": 2,
|
||||
}
|
||||
|
||||
images = pipe(**inputs).images
|
||||
|
||||
# Should return 2 images for single prompt
|
||||
self.assertEqual(len(images), 2)
|
||||
self.assertEqual(images[0].shape, (3, 32, 32))
|
||||
self.assertEqual(images[1].shape, (3, 32, 32))
|
||||
|
||||
def test_batch_with_num_images_per_prompt(self):
|
||||
"""Test batch prompts with num_images_per_prompt > 1."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
height, width = 32, 32
|
||||
|
||||
inputs = {
|
||||
"prompt": ["A photo of a cat", "A photo of a dog"],
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.5,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
"num_images_per_prompt": 2,
|
||||
}
|
||||
|
||||
images = pipe(**inputs).images
|
||||
|
||||
# Should return 4 images (2 prompts × 2 images per prompt)
|
||||
self.assertEqual(len(images), 4)
|
||||
|
||||
@unittest.skip("Needs to be revisited.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
|
||||
Reference in New Issue
Block a user