mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make flux ready for mellon (#12419)

* make flux ready for mellon

* up

* Apply suggestions from code review

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Author: Sayak Paul
Date: 2025-10-06 13:15:54 +05:30
Committed by: GitHub
Parent: ce90f9b2db
Commit: 7f3e9b8695
3 changed files with 32 additions and 8 deletions

View File

@@ -252,11 +252,13 @@ class FluxInputStep(ModularPipelineBlocks):
             InputParam(
                 "prompt_embeds",
                 required=True,
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="Pre-generated text embeddings. Can be generated from text_encoder step.",
             ),
             InputParam(
                 "pooled_prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
             ),
@@ -279,11 +281,13 @@ class FluxInputStep(ModularPipelineBlocks):
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
                 description="text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "pooled_prompt_embeds",
                 type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
                 description="pooled text embeddings used to guide the image generation",
             ),
             # TODO: support negative embeddings?
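The `kwargs_type="denoiser_input_fields"` tag added above is what lets a frontend such as Mellon treat the prompt embeddings as one named group of denoiser conditioning fields instead of a hard-coded list. As a minimal sketch (not the commit's code), a custom block can tag its own params the same way; the class name and the extra field are hypothetical, and the import paths assume the public `diffusers.modular_pipelines` exports:

import torch
from diffusers.modular_pipelines import InputParam, ModularPipelineBlocks


class MyConditioningBlock(ModularPipelineBlocks):
    """Hypothetical block whose inputs join the same denoiser_input_fields group."""

    @property
    def inputs(self):
        return [
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",  # same grouping tag as in the diff above
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings.",
            ),
            InputParam(
                "ip_adapter_embeds",  # hypothetical extra conditioning field
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Optional extra conditioning, grouped with the other denoiser fields.",
            ),
        ]

Only the `inputs` property is shown; a real block would also declare its components and implement `__call__`.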

View File

@@ -181,6 +181,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         return [
             InputParam("prompt"),
             InputParam("prompt_2"),
+            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
             InputParam("joint_attention_kwargs"),
         ]
@@ -189,16 +190,19 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
         return [
             OutputParam(
                 "prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "pooled_prompt_embeds",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="pooled text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "text_ids",
+                kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
                 description="ids from the text sequence for RoPE",
             ),
@@ -404,6 +408,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
             pooled_prompt_embeds=None,
             device=block_state.device,
             num_images_per_prompt=1,  # TODO: hardcoded for now.
+            max_sequence_length=block_state.max_sequence_length,
             lora_scale=block_state.text_encoder_lora_scale,
         )

View File

@@ -84,9 +84,9 @@ class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
 # before_denoise: all task (text2img, img2img)
 class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
-    block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
-    block_names = ["text2image", "img2img"]
-    block_trigger_inputs = [None, "image_latents"]
+    block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
+    block_names = ["img2img", "text2image"]
+    block_trigger_inputs = ["image_latents", None]
     @property
     def description(self):
@@ -124,16 +124,32 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step support text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
)
# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
block_classes = [
FluxTextEncoderStep,
FluxAutoVaeEncoderStep,
FluxAutoBeforeDenoiseStep,
FluxAutoDenoiseStep,
FluxCoreDenoiseStep,
FluxAutoDecodeStep,
]
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
block_names = ["text_encoder", "image_encoder", "denoise", "decode"]
@property
def description(self):
@@ -171,8 +187,7 @@ AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", FluxTextEncoderStep),
         ("image_encoder", FluxAutoVaeEncoderStep),
-        ("before_denoise", FluxAutoBeforeDenoiseStep),
-        ("denoise", FluxAutoDenoiseStep),
+        ("denoise", FluxCoreDenoiseStep),
         ("decode", FluxAutoDecodeStep),
     ]
 )
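Taken together, `FluxCoreDenoiseStep` now bundles input standardization, before-denoise preparation, and the denoise loop into the single `denoise` entry of `AUTO_BLOCKS`, and the auto before-denoise step resolves img2img first (trigger `image_latents`) with text2image as the `None` fallback. A minimal end-to-end sketch, assuming the Modular Diffusers runtime API (`init_pipeline`, `load_default_components`, `output="images"`) and the standard Flux checkpoint id:

import torch
from diffusers.modular_pipelines.flux import FluxAutoBlocks

# text_encoder -> image_encoder -> denoise (FluxCoreDenoiseStep) -> decode
blocks = FluxAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-dev")
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")  # assumption: ModularPipeline supports .to() like standard pipelines

# Text-to-image: no image_latents are present, so FluxAutoBeforeDenoiseStep inside
# FluxCoreDenoiseStep falls through to its text2image branch (trigger input None).
image = pipe(
    prompt="a tiny robot watering a bonsai tree",
    num_inference_steps=28,
    output="images",
)[0]
image.save("flux_modular_t2i.png")

For image-to-image, the same pipeline would additionally be given an image so that the VAE encoder step puts `image_latents` into the state and the img2img branch is triggered.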