1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[modular] i2i and t2i support for kontext modular (#12454)

* up

* get ready

* fix import

* up

* up
This commit is contained in:
Sayak Paul
2025-10-10 18:10:17 +05:30
committed by GitHub
parent a9df12ab45
commit 693d8a3a52
11 changed files with 659 additions and 46 deletions

View File

@@ -386,6 +386,8 @@ else:
_import_structure["modular_pipelines"].extend(
[
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
"FluxModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
@@ -1050,6 +1052,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
FluxModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,

View File

@@ -46,7 +46,12 @@ else:
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
@@ -65,7 +70,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxModularPipeline
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -25,14 +25,18 @@ else:
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"AUTO_BLOCKS_KONTEXT",
"FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxModularPipeline"]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -45,13 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
AUTO_BLOCKS_KONTEXT,
FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxModularPipeline
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys

View File

@@ -118,15 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# TODO: align this with Qwen patchifier
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
@@ -398,16 +389,15 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# TODO: move packing latents code to a patchifier
# TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
@@ -557,3 +547,73 @@ class FluxRoPEInputsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux-kontext"
@property
def description(self) -> str:
return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="image_height"),
InputParam(name="image_width"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="prompt_embeds"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
),
OutputParam(
name="img_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the image latents, used for RoPE calculation.",
),
]
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device, dtype = prompt_embeds.device, prompt_embeds.dtype
block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
device=prompt_embeds.device, dtype=prompt_embeds.dtype
)
img_ids = None
if (
getattr(block_state, "image_height", None) is not None
and getattr(block_state, "image_width", None) is not None
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
img_ids[..., 0] = 1
height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
if img_ids is not None:
latent_ids = torch.cat([latent_ids, img_ids], dim=0)
block_state.img_ids = latent_ids
self.set_block_state(state, block_state)
return components, state

View File

@@ -109,6 +109,96 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
return components, block_state
class FluxKontextLoopDenoiser(ModularPipelineBlocks):
model_name = "flux-kontext"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", FluxTransformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents for Flux Kontext. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `FluxDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Prompt embeddings",
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pooled prompt embeddings",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
"img_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from latent sequence needed for RoPE",
),
]
@torch.no_grad()
def __call__(
self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents
image_latents = block_state.image_latents
if image_latents is not None:
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
txt_ids=block_state.txt_ids,
img_ids=block_state.img_ids,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"
@@ -221,3 +311,20 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)
class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
model_name = "flux-kontext"
block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `FluxKontextLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)

View File

@@ -20,7 +20,7 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
@@ -83,11 +83,11 @@ def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.G
class FluxProcessImagesInputStep(ModularPipelineBlocks):
model_name = "Flux"
model_name = "flux"
@property
def description(self) -> str:
return "Image Preprocess step. Resizing is needed in Flux Kontext (will be implemented later.)"
return "Image Preprocess step."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -106,9 +106,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
]
return [OutputParam(name="processed_image")]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
@@ -142,13 +140,80 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
return components, state
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux-kontext"
def __init__(self, _auto_resize=True):
self._auto_resize = _auto_resize
super().__init__()
@property
def description(self) -> str:
return (
"Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
"Kontext works as a T2I model, too, in case no input image is provided."
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [InputParam("image")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam(name="processed_image")]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState):
from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
block_state = self.get_block_state(state)
images = block_state.image
if images is None:
block_state.processed_image = None
else:
multiple_of = components.image_processor.config.vae_scale_factor
if not is_valid_image_imagelist(images):
raise ValueError(f"Images must be image or list of images but are {type(images)}")
if is_valid_image(images):
images = [images]
img = images[0]
image_height, image_width = components.image_processor.get_default_height_width(img)
aspect_ratio = image_width / image_height
if self._auto_resize:
# Kontext is trained on specific resolutions, using one of them is recommended
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
images = components.image_processor.resize(images, image_height, image_width)
block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)
self.set_block_state(state, block_state)
return components, state
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
model_name = "flux"
def __init__(
self,
input_name: str = "processed_image",
output_name: str = "image_latents",
self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
):
"""Initialize a VAE encoder step for converting images to latent representations.
@@ -160,6 +225,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
Examples: "processed_image" or "processed_control_image"
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
Examples: "image_latents" or "control_image_latents"
sample_mode (str, optional): Sampling mode to be used.
Examples:
# Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep()
@@ -170,6 +236,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
"""
self._image_input_name = input_name
self._image_latents_output_name = output_name
self.sample_mode = sample_mode
super().__init__()
@property
@@ -183,7 +250,7 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
inputs = [InputParam(self._image_input_name), InputParam("generator")]
return inputs
@property
@@ -199,16 +266,20 @@ class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
image = image.to(device=device, dtype=dtype)
# Encode image into latents
image_latents = encode_vae_image(image=image, vae=components.vae, generator=block_state.generator)
setattr(block_state, self._image_latents_output_name, image_latents)
if image is None:
setattr(block_state, self._image_latents_output_name, None)
else:
device = components._execution_device
dtype = components.vae.dtype
image = image.to(device=device, dtype=dtype)
# Encode image into latents
image_latents = encode_vae_image(
image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
)
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)

View File

@@ -17,6 +17,7 @@ from typing import List
import torch
from ...pipelines import FluxPipeline
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam
@@ -25,6 +26,9 @@ from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_t
from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__)
class FluxTextInputStep(ModularPipelineBlocks):
model_name = "flux"
@@ -234,3 +238,122 @@ class FluxInputsDynamicStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
model_name = "flux-kontext"
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents
# Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
if not hasattr(block_state, "image_height"):
block_state.image_height = height
if not hasattr(block_state, "image_width"):
block_state.image_width = width
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)
# 3. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class FluxKontextSetResolutionStep(ModularPipelineBlocks):
model_name = "flux-kontext"
def description(self):
return (
"Determines the height and width to be used during the subsequent computations.\n"
"It should always be placed _before_ the latent preparation step."
)
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="max_area", type_hint=int, default=1024**2),
]
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
height = block_state.height or components.default_height
width = block_state.width or components.default_width
self.check_inputs(height, width, components.vae_scale_factor)
original_height, original_width = height, width
max_area = block_state.max_area
aspect_ratio = width / height
width = round((max_area * aspect_ratio) ** 0.5)
height = round((max_area / aspect_ratio) ** 0.5)
multiple_of = components.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
if height != original_height or width != original_width:
logger.warning(
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
)
block_state.height = height
block_state.width = width
self.set_block_state(state, block_state)
return components, state

View File

@@ -18,14 +18,25 @@ from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
FluxKontextRoPEInputsStep,
FluxPrepareLatentsStep,
FluxRoPEInputsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxProcessImagesInputStep, FluxTextEncoderStep, FluxVaeEncoderDynamicStep
from .inputs import FluxInputsDynamicStep, FluxTextInputStep
from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
from .encoders import (
FluxKontextProcessImagesInputStep,
FluxProcessImagesInputStep,
FluxTextEncoderStep,
FluxVaeEncoderDynamicStep,
)
from .inputs import (
FluxInputsDynamicStep,
FluxKontextInputsDynamicStep,
FluxKontextSetResolutionStep,
FluxTextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -33,10 +44,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# vae encoder (run before before_denoise)
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
[
("preprocess", FluxProcessImagesInputStep()),
("encode", FluxVaeEncoderDynamicStep()),
]
[("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
)
@@ -66,6 +74,39 @@ class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
)
# Flux Kontext vae encoder (run before before_denoise)
FluxKontextVaeEncoderBlocks = InsertableDict(
[("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
)
class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = FluxKontextVaeEncoderBlocks.values()
block_names = FluxKontextVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [FluxKontextVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is not provided, step will be skipped."
)
# before_denoise: text2img
FluxBeforeDenoiseBlocks = InsertableDict(
[
@@ -107,6 +148,7 @@ class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@@ -121,6 +163,44 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
)
# before_denoise: FluxKontext
FluxKontextBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
]
)
class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = FluxKontextBeforeDenoiseBlocks.values()
block_names = FluxKontextBeforeDenoiseBlocks.keys()
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step\n"
"for img2img/text2img task for Flux Kontext."
)
class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
@@ -136,6 +216,23 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
)
# denoise: Flux Kontext
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents for Flux Kontext. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
block_classes = [FluxDecodeStep]
@@ -165,7 +262,7 @@ class FluxImg2ImgInputStep(SequentialPipelineBlocks):
" - update height/width based `image_latents`, patchify `image_latents`."
class FluxImageAutoInputStep(AutoPipelineBlocks):
class FluxAutoInputStep(AutoPipelineBlocks):
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@@ -180,16 +277,59 @@ class FluxImageAutoInputStep(AutoPipelineBlocks):
)
# inputs: Flux Kontext
FluxKontextBlocks = InsertableDict(
[
("set_resolution", FluxKontextSetResolutionStep()),
("text_inputs", FluxTextInputStep()),
("additional_inputs", FluxKontextInputsDynamicStep()),
]
)
class FluxKontextInputStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = FluxKontextBlocks.values()
block_names = FluxKontextBlocks.keys()
@property
def description(self):
return (
"Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
class FluxKontextAutoInputStep(AutoPipelineBlocks):
block_classes = [FluxKontextInputStep, FluxTextInputStep]
# block_classes = [FluxKontextInputStep]
block_names = ["img2img", "text2img"]
# block_names = ["img2img"]
block_trigger_inputs = ["image_latents", None]
# block_trigger_inputs = ["image_latents"]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
)
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
@@ -198,6 +338,24 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
)
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
)
# Auto blocks (text2image and img2img)
AUTO_BLOCKS = InsertableDict(
[
@@ -208,6 +366,15 @@ AUTO_BLOCKS = InsertableDict(
]
)
AUTO_BLOCKS_KONTEXT = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("image_encoder", FluxKontextAutoVaeEncoderStep()),
("denoise", FluxKontextCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
class FluxAutoBlocks(SequentialPipelineBlocks):
model_name = "flux"
@@ -224,6 +391,13 @@ class FluxAutoBlocks(SequentialPipelineBlocks):
)
class FluxKontextAutoBlocks(FluxAutoBlocks):
model_name = "flux-kontext"
block_classes = AUTO_BLOCKS_KONTEXT.values()
block_names = AUTO_BLOCKS_KONTEXT.keys()
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
@@ -250,4 +424,23 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
]
)
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
FLUX_KONTEXT_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
("input", FluxKontextInputStep()),
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
("denoise", FluxKontextDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"auto": AUTO_BLOCKS,
"auto_kontext": AUTO_BLOCKS_KONTEXT,
"kontext": FLUX_KONTEXT_BLOCKS,
}

View File

@@ -55,3 +55,13 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class FluxKontextModularPipeline(FluxModularPipeline):
"""
A ModularPipeline for Flux Kontext.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "FluxKontextAutoBlocks"

View File

@@ -57,6 +57,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

View File

@@ -17,6 +17,36 @@ class FluxAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxKontextAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxKontextModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]