diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 24b9c12db6..71228a5598 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -414,6 +414,9 @@ else:
         [
             "Flux2AutoBlocks",
+            "Flux2KleinAutoBlocks",
+            "Flux2KleinBaseAutoBlocks",
+            "Flux2KleinModularPipeline",
             "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
             "FluxKontextModularPipeline",
@@ -1147,6 +1150,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .modular_pipelines import (
         Flux2AutoBlocks,
+        Flux2KleinAutoBlocks,
+        Flux2KleinBaseAutoBlocks,
+        Flux2KleinModularPipeline,
         Flux2ModularPipeline,
         FluxAutoBlocks,
         FluxKontextAutoBlocks,
         FluxKontextModularPipeline,
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
index e64db23f38..099e86a553 100644
--- a/src/diffusers/modular_pipelines/__init__.py
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -54,7 +54,10 @@ else:
     ]
     _import_structure["flux2"] = [
         "Flux2AutoBlocks",
+        "Flux2KleinAutoBlocks",
+        "Flux2KleinBaseAutoBlocks",
+        "Flux2KleinModularPipeline",
         "Flux2ModularPipeline",
     ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 else:
     from .components_manager import ComponentsManager
     from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
-    from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
+    from .flux2 import (
+        Flux2AutoBlocks,
+        Flux2KleinAutoBlocks,
+        Flux2KleinBaseAutoBlocks,
+        Flux2KleinModularPipeline,
+        Flux2ModularPipeline,
+    )
     from .modular_pipeline import (
         AutoPipelineBlocks,
         BlockState,
diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py
index 21a41c1fe9..64ced29bdd 100644
--- a/src/diffusers/modular_pipelines/flux2/__init__.py
+++ b/src/diffusers/modular_pipelines/flux2/__init__.py
@@ -43,7 +43,7 @@ else:
         "Flux2ProcessImagesInputStep",
         "Flux2TextInputStep",
     ]
-    _import_structure["modular_blocks"] = [
+    _import_structure["modular_blocks_flux2"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
         "REMOTE_AUTO_BLOCKS",
@@ -54,7 +54,8 @@ else:
         "Flux2AutoBlocks",
         "Flux2AutoBeforeDenoiseStep",
         "Flux2BeforeDenoiseStep",
         "Flux2VaeEncoderSequentialStep",
     ]
-    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
+    _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
+    _import_structure["modular_pipeline"] = ["Flux2KleinModularPipeline", "Flux2ModularPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         Flux2ProcessImagesInputStep,
         Flux2TextInputStep,
     )
-    from .modular_blocks import (
+    from .modular_blocks_flux2 import (
         ALL_BLOCKS,
         AUTO_BLOCKS,
         IMAGE_CONDITIONED_BLOCKS,
@@ -96,7 +97,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         Flux2BeforeDenoiseStep,
         Flux2VaeEncoderSequentialStep,
     )
-    from .modular_pipeline import Flux2ModularPipeline
+    from .modular_blocks_flux2_klein import (
+        Flux2KleinAutoBlocks,
+        Flux2KleinBaseAutoBlocks,
+    )
+    from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
 else:
     import sys
diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py
index 42624688ad..e1001924c7 100644
--- a/src/diffusers/modular_pipelines/flux2/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py
@@ -353,7 +353,7 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
InputParam(name="prompt_embeds", required=True), - InputParam(name="latent_ids"), + InputParam(name="negative_prompt_embeds", required=False), ] @property @@ -366,10 +366,10 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks): description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", ), OutputParam( - name="latent_ids", + name="negative_txt_ids", kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, - description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.", + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", ), ] @@ -399,6 +399,11 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks): block_state.txt_ids = self._prepare_text_ids(prompt_embeds) block_state.txt_ids = block_state.txt_ids.to(device) + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index c12eca65c6..84cad52ab7 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -13,20 +13,23 @@ # limitations under the License. from typing import Any, List, Tuple +import inspect import torch +from ...configuration_utils import FrozenDict from ...models import Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging +from ...guiders import ClassifierFreeGuidance from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import Flux2ModularPipeline +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec +from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline if is_torch_xla_available(): @@ -133,6 +136,241 @@ class Flux2LoopDenoiser(ModularPipelineBlocks): return components, block_state +# sane as Flux2 but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ) + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + cond_kwargs.update(additional_cond_kwargs) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + + return components, block_state + class Flux2LoopAfterDenoiser(ModularPipelineBlocks): model_name = "flux2" @@ -220,6 +458,8 @@ class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 ) + block_state.additional_cond_kwargs = {} + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in 
                 components, block_state = self.loop_step(components, block_state, i=i, t=t)
@@ -250,3 +490,34 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
             " - `Flux2LoopAfterDenoiser`\n"
             "This block supports both text-to-image and image-conditioned generation."
         )
+
+
+class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
+    block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents for Flux2-Klein. \n"
+            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `Flux2KleinLoopDenoiser`\n"
+            " - `Flux2LoopAfterDenoiser`\n"
+            "This block supports both text-to-image and image-conditioned generation."
+        )
+
+
+class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
+    block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents for the Flux2-Klein base model. \n"
+            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `Flux2KleinBaseLoopDenoiser`\n"
+            " - `Flux2LoopAfterDenoiser`\n"
+            "This block supports both text-to-image and image-conditioned generation."
+        )
diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py
index 6cb0e3bf0a..5c06746aa2 100644
--- a/src/diffusers/modular_pipelines/flux2/encoders.py
+++ b/src/diffusers/modular_pipelines/flux2/encoders.py
@@ -15,13 +15,13 @@ from typing import List, Optional, Tuple, Union
 
 import torch
-from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
 
 from ...models import AutoencoderKLFlux2
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
-from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import Flux2ModularPipeline
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -79,10 +79,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("prompt"),
-            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
             InputParam("max_sequence_length", type_hint=int, default=512, required=False),
             InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
-            InputParam("joint_attention_kwargs"),
         ]
 
     @property
@@ -99,14 +97,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
     @staticmethod
     def check_inputs(block_state):
         prompt = block_state.prompt
-        prompt_embeds = getattr(block_state, "prompt_embeds", None)
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
-                "Please make sure to only forward one of the two."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
     @staticmethod
@@ -165,10 +156,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
 
         block_state.device = components._execution_device
 
-        if block_state.prompt_embeds is not None:
-            self.set_block_state(state, block_state)
-            return components, state
-
         prompt = block_state.prompt
         if prompt is None:
             prompt = ""
@@ -205,7 +192,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("prompt"),
-            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
         ]
 
     @property
@@ -222,15 +208,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
     @staticmethod
     def check_inputs(block_state):
         prompt = block_state.prompt
-        prompt_embeds = getattr(block_state, "prompt_embeds", None)
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
-                "Please make sure to only forward one of the two."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
     @torch.no_grad()
     def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -244,10 +223,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
 
         block_state.device = components._execution_device
 
-        if block_state.prompt_embeds is not None:
-            self.set_block_state(state, block_state)
-            return components, state
-
         prompt = block_state.prompt
         if prompt is None:
             prompt = ""
@@ -270,6 +245,156 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
 
         return components, state
 
+
+class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen3ForCausalLM),
+            ComponentSpec("tokenizer", Qwen2TokenizerFast),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(name="is_distilled", default=False),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("prompt"),
+            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
+            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "prompt_embeds",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="Text embeddings from Qwen3 used to guide the image generation",
+            ),
+            OutputParam(
+                "negative_prompt_embeds",
+                type_hint=torch.Tensor,
+                description="Negative text embeddings from Qwen3 used to guide the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        prompt = block_state.prompt
+        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    @staticmethod
+    # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
+    def _get_qwen3_prompt_embeds(
+        text_encoder: Qwen3ForCausalLM,
+        tokenizer: Qwen2TokenizerFast,
+        prompt: Union[str, List[str]],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        max_sequence_length: int = 512,
+        hidden_states_layers: List[int] = (9, 18, 27),
+    ):
+        dtype = text_encoder.dtype if dtype is None else dtype
+        device = text_encoder.device if device is None else device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        all_input_ids = []
+        all_attention_masks = []
+
+        for single_prompt in prompt:
+            messages = [{"role": "user", "content": single_prompt}]
+            text = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=max_sequence_length,
+            )
+
+            all_input_ids.append(inputs["input_ids"])
+            all_attention_masks.append(inputs["attention_mask"])
+
+        input_ids = torch.cat(all_input_ids, dim=0).to(device)
+        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+        # Forward pass through the model
+        output = text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            use_cache=False,
+        )
+
+        # Only use outputs from intermediate layers and stack them
+        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+        out = out.to(dtype=dtype, device=device)
+
+        batch_size, num_channels, seq_len, hidden_dim = out.shape
+        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+        return prompt_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        device = components._execution_device
+
+        prompt = block_state.prompt
+        if prompt is None:
+            prompt = ""
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
+            text_encoder=components.text_encoder,
+            tokenizer=components.tokenizer,
+            prompt=prompt,
+            device=device,
+            max_sequence_length=block_state.max_sequence_length,
+            hidden_states_layers=block_state.text_encoder_out_layers,
+        )
+
+        # the base (non-distilled) model needs unconditional embeddings for
+        # classifier-free guidance; distilled checkpoints skip this branch
+        if components.requires_unconditional_embeds:
+            negative_prompt = ""
+            block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
+                text_encoder=components.text_encoder,
+                tokenizer=components.tokenizer,
+                prompt=negative_prompt,
+                device=device,
+                max_sequence_length=block_state.max_sequence_length,
+                hidden_states_layers=block_state.text_encoder_out_layers,
+            )
+        else:
+            block_state.negative_prompt_embeds = None
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class Flux2VaeEncoderStep(ModularPipelineBlocks):
     model_name = "flux2"
diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py
index c9e337fb0b..0de0040c39 100644
--- a/src/diffusers/modular_pipelines/flux2/inputs.py
+++ b/src/diffusers/modular_pipelines/flux2/inputs.py
@@ -47,7 +47,14 @@ class Flux2TextInputStep(ModularPipelineBlocks):
                 required=True,
                 kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.", + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", ), ] @@ -70,6 +77,12 @@ class Flux2TextInputStep(ModularPipelineBlocks): kwargs_type="denoiser_input_fields", description="Text embeddings used to guide the image generation", ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), ] @torch.no_grad() @@ -85,6 +98,13 @@ class Flux2TextInputStep(ModularPipelineBlocks): block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py similarity index 100% rename from src/diffusers/modular_pipelines/flux2/modular_blocks.py rename to src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 0000000000..2f89106b13 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,164 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    Flux2PrepareImageLatentsStep,
+    Flux2PrepareLatentsStep,
+    Flux2RoPEInputsStep,
+    Flux2SetTimestepsStep,
+)
+from .decoders import Flux2DecodeStep
+from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
+from .encoders import (
+    Flux2KleinTextEncoderStep,
+    Flux2VaeEncoderStep,
+)
+from .inputs import (
+    Flux2ProcessImagesInputStep,
+    Flux2TextInputStep,
+)
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+Flux2KleinVaeEncoderBlocks = InsertableDict(
+    [
+        ("preprocess", Flux2ProcessImagesInputStep()),
+        ("encode", Flux2VaeEncoderStep()),
+    ]
+)
+
+
+class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = Flux2KleinVaeEncoderBlocks.values()
+    block_names = Flux2KleinVaeEncoderBlocks.keys()
+
+    @property
+    def description(self) -> str:
+        return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [Flux2KleinVaeEncoderSequentialStep]
+    block_names = ["img_conditioning"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "VAE encoder step that encodes the image inputs into their latent representations.\n"
+            "This is an auto pipeline block that works for image conditioning tasks.\n"
+            " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
+            " - If `image` is not provided, step will be skipped."
+        )
+
+
+Flux2KleinCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2TextInputStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("denoise", Flux2KleinDenoiseStep()),
+    ]
+)
+
+
+class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = Flux2KleinCoreDenoiseBlocks.values()
+    block_names = Flux2KleinCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return (
+            "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
+            " - `Flux2TextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
+            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
+            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
+            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
+            " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
+        )
+
+
+Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2TextInputStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("denoise", Flux2KleinBaseDenoiseStep()),
+    ]
+)
+
+
+class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
+    block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return (
+            "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
+            " - `Flux2TextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
+            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
+            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
+            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids and negative_txt_ids) for the denoising step.\n"
+            " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using classifier-free guidance.\n"
+        )
+
+
+class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = [
+        Flux2KleinTextEncoderStep(),
+        Flux2KleinAutoVaeEncoderStep(),
+        Flux2KleinCoreDenoiseStep(),
+        Flux2DecodeStep(),
+    ]
+    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
+
+    @property
+    def description(self):
+        return (
+            "Auto blocks that perform text-to-image and image-conditioned generation using Flux2-Klein (distilled model).\n"
+            " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+            " - for text-to-image generation, all you need to provide is `prompt`.\n"
+        )
+
+
+class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = [
+        Flux2KleinTextEncoderStep(),
+        Flux2KleinAutoVaeEncoderStep(),
+        Flux2KleinBaseCoreDenoiseStep(),
+        Flux2DecodeStep(),
+    ]
+    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
+
+    @property
+    def description(self):
+        return (
+            "Auto blocks that perform text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+            " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+            " - for text-to-image generation, all you need to provide is `prompt`.\n"
+        )
diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py
index 3e497f3b1e..e37dafcfce 100644
--- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py
@@ -17,6 +17,8 @@
+from typing import Any, Dict, Optional
+
 from ...loaders import Flux2LoraLoaderMixin
 from ...utils import logging
 from ..modular_pipeline import ModularPipeline
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,3 +57,60 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
         if getattr(self, "transformer", None):
             num_channels_latents = self.transformer.config.in_channels // 4
         return num_channels_latents
+
+
+class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
+    """
+    A ModularPipeline for Flux2-Klein.
+
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
+    """
+
+    default_blocks_name = "Flux2KleinBaseAutoBlocks"
+
+    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
+        # distilled checkpoints do not use classifier-free guidance, so they map to
+        # the plain Klein blocks; everything else gets the CFG-enabled base blocks
+        if config_dict is not None and config_dict.get("is_distilled"):
+            return "Flux2KleinAutoBlocks"
+        else:
+            return "Flux2KleinBaseAutoBlocks"
+
+    @property
+    def default_height(self):
+        return self.default_sample_size * self.vae_scale_factor
+
+    @property
+    def default_width(self):
+        return self.default_sample_size * self.vae_scale_factor
+
+    @property
+    def default_sample_size(self):
+        return 128
+
+    @property
+    def vae_scale_factor(self):
+        vae_scale_factor = 8
+        if getattr(self, "vae", None) is not None:
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        return vae_scale_factor
+
+    @property
+    def num_channels_latents(self):
+        num_channels_latents = 32
+        if getattr(self, "transformer", None):
+            num_channels_latents = self.transformer.config.in_channels // 4
+        return num_channels_latents
+
+    @property
+    def requires_unconditional_embeds(self):
+        if hasattr(self.config, "is_distilled") and self.config.is_distilled:
+            return False
+
+        requires_unconditional_embeds = False
+        if hasattr(self, "guider") and self.guider is not None:
+            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+        return requires_unconditional_embeds
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index d857fd0409..98ede73c21 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
         ("flux", "FluxModularPipeline"),
         ("flux-kontext", "FluxKontextModularPipeline"),
         ("flux2", "Flux2ModularPipeline"),
+        ("flux2-klein", "Flux2KleinModularPipeline"),
         ("qwenimage", "QwenImageModularPipeline"),
         ("qwenimage-edit", "QwenImageEditModularPipeline"),
         ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
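
Usage sketch: the snippet below shows one way to drive the new Klein auto blocks through the modular-pipelines API. The repo id is a placeholder, and the exact component-loading call (`load_components` here) may differ slightly across diffusers versions.

    import torch
    from diffusers.modular_pipelines import Flux2KleinAutoBlocks

    # assemble the preset blocks into a runnable pipeline;
    # "<klein-modular-repo>" is a hypothetical modular checkpoint id
    blocks = Flux2KleinAutoBlocks()
    pipe = blocks.init_pipeline("<klein-modular-repo>")
    pipe.load_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # text-to-image; pass image=[...] for image-conditioned generation
    image = pipe(prompt="a photo of a forest at dawn", num_inference_steps=28, output="images")[0]
    image.save("klein.png")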