mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
yiyi@huggingface.co
2026-01-20 01:31:41 +00:00
parent 618a8a9897
commit fb2cb18f73
8 changed files with 48 additions and 39 deletions

View File

@@ -413,10 +413,10 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -1149,10 +1149,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
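
For orientation: the two hunks above only reorder entries in the import lists, so the Flux2 Klein entry points stay importable exactly as before. A minimal usage sketch, assuming these names are re-exported from the top-level diffusers package as the `_import_structure` / `from .modular_pipelines import ...` pattern here suggests; nothing below is added by this commit, and the instantiation is illustrative only:

# Sketch only: exercises the names listed in the hunks above; not part of the commit.
from diffusers import (
    Flux2KleinAutoBlocks,
    Flux2KleinBaseAutoBlocks,
    Flux2KleinModularPipeline,
)

blocks = Flux2KleinAutoBlocks()  # preset blocks; see the modular_blocks hunks further down
print(blocks.block_names)        # ["text_encoder", "vae_image_encoder", "denoise", "decode"] per that file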

View File

@@ -84,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2ModularPipeline, Flux2KleinModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -101,7 +101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import inspect
from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ...guiders import ClassifierFreeGuidance
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
if is_torch_xla_available():
@@ -136,7 +136,8 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
return components, block_state
# sane as Flux2 but guidance=None
# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@@ -308,7 +309,7 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
)
),
]
@torch.no_grad()
@@ -368,7 +369,6 @@ class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
# perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
@@ -491,6 +491,7 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]

View File

@@ -15,13 +15,13 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen3ForCausalLM, Qwen2TokenizerFast
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -245,11 +245,9 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
return components, state
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@@ -284,7 +282,6 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
@@ -390,7 +387,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state

View File

@@ -100,7 +100,9 @@ class Flux2TextInputStep(ModularPipelineBlocks):
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)

View File

@@ -22,7 +22,7 @@ from .before_denoise import (
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2KleinDenoiseStep, Flux2KleinBaseDenoiseStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
@@ -55,7 +55,6 @@ class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
@@ -71,9 +70,8 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
)
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
@@ -89,7 +87,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
@@ -105,7 +103,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
[
("input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
@@ -115,11 +113,12 @@ Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
@@ -134,12 +133,16 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
)
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinCoreDenoiseStep(), Flux2DecodeStep()]
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
@property
def description(self):
return (
@@ -149,12 +152,16 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
)
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep()]
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
@property
def description(self):
return (

View File

@@ -13,12 +13,12 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
from typing import Optional, Dict, Any
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -59,8 +59,6 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.
@@ -71,7 +69,6 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
@@ -105,7 +102,6 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
@@ -113,4 +109,4 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
return requires_unconditional_embeds
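
Read together, the hunks in this last file show how the Klein pipeline keys its behaviour off the `is_distilled` config flag: distilled checkpoints map to the guidance-free Flux2KleinAutoBlocks and never request unconditional embeddings, while non-distilled checkpoints keep the class default `Flux2KleinBaseAutoBlocks` and defer to the guider. A hedged paraphrase of that logic, condensed into standalone helpers (the helper names are hypothetical, and the no-guider default is not visible in the hunks, so False is assumed):

# Paraphrase of the selection logic visible above; not the file's exact code.
def _klein_blocks_for(config_dict):
    # Distilled Klein checkpoints get the guidance-free auto blocks.
    if config_dict is not None and config_dict.get("is_distilled"):
        return "Flux2KleinAutoBlocks"
    # Otherwise the class default "Flux2KleinBaseAutoBlocks" (CFG-capable) applies.
    return "Flux2KleinBaseAutoBlocks"

def _needs_unconditional_embeds(pipe):
    # Distilled checkpoints never require negative/unconditional embeddings.
    if getattr(pipe.config, "is_distilled", False):
        return False
    # Base checkpoints need them only when an enabled guider expects more than one condition.
    guider = getattr(pipe, "guider", None)  # initial value without a guider assumed False here
    return bool(guider is not None and guider._enabled and guider.num_conditions > 1)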