1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

flux2-dev work in modular setting

This commit is contained in:
yiyixuxu
2026-01-21 08:06:32 +01:00
parent f49c68cecf
commit 1c500c8eeb
2 changed files with 53 additions and 22 deletions

View File

@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import PIL.Image
from typing import List
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2PrepareGuidanceStep,
Flux2PrepareImageLatentsStep,
@@ -42,7 +46,6 @@ Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
@@ -73,35 +76,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
)
Flux2BeforeDenoiseBlocks = InsertableDict(
Flux2CoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
return (
"Core denoise step that performs the denoising process for Flux2-dev.\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -110,11 +134,8 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -134,6 +155,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
TEXT2IMAGE_BLOCKS = InsertableDict(
[

View File

@@ -43,9 +43,10 @@ from .inputs import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
###
### VAE encoder
###
################
# VAE encoder
################
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
@@ -105,9 +106,8 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein.\n"
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"