diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 71228a5598..52ec30c536 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -413,10 +413,10 @@ else: _import_structure["modular_pipelines"].extend( [ "Flux2AutoBlocks", - "Flux2ModularPipeline", "Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks", "Flux2KleinModularPipeline", + "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -1149,10 +1149,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .modular_pipelines import ( Flux2AutoBlocks, - Flux2ModularPipeline, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2KleinModularPipeline, + Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 099e86a553..823a3d263e 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -84,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline - from .flux2 import Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2ModularPipeline, Flux2KleinModularPipeline + from .flux2 import ( + Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, + Flux2ModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 64ced29bdd..fb97a56fb0 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -101,7 +101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, ) - from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline + from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index 84cad52ab7..b2e1b41dde 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple import inspect +from typing import Any, List, Tuple import torch from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging -from ...guiders import ClassifierFreeGuidance from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline if is_torch_xla_available(): @@ -136,7 +136,8 @@ class Flux2LoopDenoiser(ModularPipelineBlocks): return components, block_state -# sane as Flux2 but guidance=None + +# same as Flux2LoopDenoiser but guidance=None class Flux2KleinLoopDenoiser(ModularPipelineBlocks): model_name = "flux2-klein" @@ -308,7 +309,7 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): InputParam( kwargs_type="denoiser_input_fields", description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ) + ), ] @torch.no_grad() @@ -368,7 +369,6 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): # perform guidance block_state.noise_pred = components.guider(guider_state)[0] - return components, block_state @@ -491,6 +491,7 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): "This block supports both text-to-image and image-conditioned generation." ) + class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 5c06746aa2..1d9e56bdf0 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -15,13 +15,13 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen3ForCausalLM, Qwen2TokenizerFast +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -245,11 +245,9 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): return components, state - class Flux2KleinTextEncoderStep(ModularPipelineBlocks): model_name = "flux2-klein" - @property def description(self) -> str: return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" @@ -284,7 +282,6 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks): type_hint=torch.Tensor, description="Text embeddings from qwen3 used to guide the image generation", ), - OutputParam( "negative_prompt_embeds", type_hint=torch.Tensor, @@ -390,7 +387,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks): ) else: block_state.negative_prompt_embeds = None - + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py index 0de0040c39..cc078c8262 100644 --- a/src/diffusers/modular_pipelines/flux2/inputs.py +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -100,7 +100,9 @@ class Flux2TextInputStep(ModularPipelineBlocks): if block_state.negative_prompt_embeds is not None: _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 2f89106b13..1dd63a6123 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -22,7 +22,7 @@ from .before_denoise import ( Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep -from .denoise import Flux2KleinDenoiseStep, Flux2KleinBaseDenoiseStep +from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep from .encoders import ( Flux2KleinTextEncoderStep, Flux2VaeEncoderStep, @@ -55,7 +55,6 @@ class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." - class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [Flux2KleinVaeEncoderSequentialStep] block_names = ["img_conditioning"] @@ -71,9 +70,8 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): ) - Flux2KleinCoreDenoiseBlocks = InsertableDict( - [ + [ ("input", Flux2TextInputStep()), ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), @@ -89,7 +87,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): block_classes = Flux2KleinCoreDenoiseBlocks.values() block_names = Flux2KleinCoreDenoiseBlocks.keys() - + @property def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." @@ -105,7 +103,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( - [ + [ ("input", Flux2TextInputStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("prepare_image_latents", Flux2PrepareImageLatentsStep()), @@ -115,11 +113,12 @@ Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( ] ) + class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks): model_name = "flux2-klein" block_classes = Flux2KleinBaseCoreDenoiseBlocks.values() block_names = Flux2KleinBaseCoreDenoiseBlocks.keys() - + @property def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (base model)." @@ -134,12 +133,16 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks): ) - class Flux2KleinAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" - block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinCoreDenoiseStep(), Flux2DecodeStep()] + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinCoreDenoiseStep(), + Flux2DecodeStep(), + ] block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] - + @property def description(self): return ( @@ -149,12 +152,16 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks): ) - class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" - block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep()] + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinBaseCoreDenoiseStep(), + Flux2DecodeStep(), + ] block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] - + @property def description(self): return ( diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index e37dafcfce..29fbeba07c 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -13,12 +13,12 @@ # limitations under the License. +from typing import Any, Dict, Optional + from ...loaders import Flux2LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline -from typing import Optional, Dict, Any - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -59,8 +59,6 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): return num_channels_latents - - class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): """ A ModularPipeline for Flux2-Klein. @@ -71,7 +69,6 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): default_blocks_name = "Flux2KleinBaseAutoBlocks" def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: return "Flux2KleinAutoBlocks" else: @@ -105,7 +102,6 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): @property def requires_unconditional_embeds(self): - if hasattr(self.config, "is_distilled") and self.config.is_distilled: return False @@ -113,4 +109,4 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): if hasattr(self, "guider") and self.guider is not None: requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 - return requires_unconditional_embeds \ No newline at end of file + return requires_unconditional_embeds