1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove guidannce to its own block

This commit is contained in:
yiyixuxu
2026-01-21 00:59:56 +01:00
parent e13377e841
commit 5c1fc4489f
4 changed files with 211 additions and 22 deletions

View File

@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
scheduler = components.scheduler
@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
@@ -349,6 +335,60 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
@@ -511,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the guidance scale tensor for Flux2 inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("guidance_scale", default=4.0),
InputParam("num_images_per_prompt", default=1),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state

View File

@@ -30,6 +30,68 @@ logger = logging.get_logger(__name__)
class Flux2TextInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return (

View File

@@ -19,6 +19,7 @@ from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2PrepareGuidanceStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -76,6 +77,7 @@ Flux2BeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
]
)
@@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
@@ -155,6 +158,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),

View File

@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import PIL.Image
from typing import List
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2KleinBaseRoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -31,6 +35,7 @@ from .encoders import (
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
Flux2KleinBaseTextInputStep,
)
@@ -101,7 +106,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein.\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
@@ -110,14 +115,24 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
@@ -134,14 +149,23 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
###
@@ -165,6 +189,16 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
@@ -183,3 +217,13 @@ class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]