mirror of https://github.com/huggingface/diffusers.git
move guidance preparation into its own block
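This commit extracts the guidance-scale tensor preparation out of `Flux2SetTimestepsStep` into a dedicated `Flux2PrepareGuidanceStep` block and registers it in the flux2 block presets. It also splits the Flux2-Klein base-model variants (`Flux2KleinBaseTextInputStep`, `Flux2KleinBaseRoPEInputsStep`) out of the shared text-input and RoPE steps, and declares explicit `outputs` on the Klein core-denoise and auto blocks.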
@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
-            InputParam("guidance_scale", default=4.0),
             InputParam("latents", type_hint=torch.Tensor),
-            InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-            InputParam(
-                "batch_size",
-                required=True,
-                type_hint=int,
-                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
-            ),
         ]

     @property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
         ]

     @torch.no_grad()
     def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device

         scheduler = components.scheduler

@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         timesteps, num_inference_steps = retrieve_timesteps(
             scheduler,
             num_inference_steps,
-            block_state.device,
+            device,
             timesteps=timesteps,
             sigmas=sigmas,
             mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
         block_state.timesteps = timesteps
         block_state.num_inference_steps = num_inference_steps

-        batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-        guidance = guidance.expand(batch_size)
-        block_state.guidance = guidance
-
         components.scheduler.set_begin_index(0)

         self.set_block_state(state, block_state)
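Note: the guidance lines removed here reappear, with `block_state.device` replaced by the local `device`, in the new `Flux2PrepareGuidanceStep` added at the end of this file (hunk `@@ -511,3 +551,42 @@`).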
@@ -349,6 +335,60 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
     def description(self) -> str:
         return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."

+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt_embeds", required=True),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
+            ),
+        ]
+
+    @staticmethod
+    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
+        """Prepare 4D position IDs for text tokens."""
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            seq_l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, seq_l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device = prompt_embeds.device
+
+        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
+        block_state.txt_ids = block_state.txt_ids.to(device)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
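For reference, the `_prepare_text_ids` helper added above collapses T/H/W to a single zero coordinate, so `torch.cartesian_prod` emits one `(t, h, w, l)` row per token. A minimal standalone sketch of the same construction (torch only; shapes arbitrary, not part of the commit):

    import torch

    x = torch.randn(2, 5, 16)  # stand-in for prompt_embeds: (B, L, D)
    B, L, _ = x.shape
    # t, h, w are length-1 (zero) ranges; seq_l enumerates the L tokens,
    # matching _prepare_text_ids with t_coord=None.
    ids = torch.stack([
        torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(L))
        for _ in range(B)
    ])
    print(ids.shape)  # torch.Size([2, 5, 4]); ids[0, :, 3] runs 0..L-1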
@@ -511,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):

         self.set_block_state(state, block_state)
         return components, state
+
+
+class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the guidance scale tensor for Flux2 inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("guidance_scale", default=4.0),
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        device = components._execution_device
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(batch_size)
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
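The extracted step reduces to broadcasting one scalar per final batch entry. A minimal sketch of the same math on CPU (values arbitrary, not part of the commit):

    import torch

    guidance_scale, batch_size, num_images_per_prompt = 4.0, 2, 2
    # As in Flux2PrepareGuidanceStep.__call__: one scalar, expanded to
    # batch_size * num_images_per_prompt entries.
    guidance = torch.full([1], guidance_scale, dtype=torch.float32)
    guidance = guidance.expand(batch_size * num_images_per_prompt)
    print(guidance)  # tensor([4., 4., 4., 4.])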
@@ -30,6 +30,68 @@ logger = logging.get_logger(__name__)
 class Flux2TextInputStep(ModularPipelineBlocks):
     model_name = "flux2"

+    @property
+    def description(self) -> str:
+        return (
+            "This step:\n"
+            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "batch_size",
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+            ),
+            OutputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+            ),
+            OutputParam(
+                "prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
+                description="Text embeddings used to guide the image generation",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.batch_size = block_state.prompt_embeds.shape[0]
+        block_state.dtype = block_state.prompt_embeds.dtype
+
+        _, seq_len, _ = block_state.prompt_embeds.shape
+        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+        block_state.prompt_embeds = block_state.prompt_embeds.view(
+            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
     @property
     def description(self) -> str:
         return (
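The batch expansion in `Flux2TextInputStep.__call__` is the usual repeat-then-view idiom: each prompt's embeddings are duplicated `num_images_per_prompt` times along the batch dimension. A standalone sketch (shapes arbitrary, not part of the commit):

    import torch

    batch_size, seq_len, dim, num_images_per_prompt = 2, 7, 32, 3
    prompt_embeds = torch.randn(batch_size, seq_len, dim)
    # (B, L, D) -> (B, n*L, D) -> (B*n, L, D)
    out = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    out = out.view(batch_size * num_images_per_prompt, seq_len, -1)
    print(out.shape)  # torch.Size([6, 7, 32])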
@@ -19,6 +19,7 @@ from .before_denoise import (
     Flux2PrepareImageLatentsStep,
     Flux2PrepareLatentsStep,
     Flux2RoPEInputsStep,
+    Flux2PrepareGuidanceStep,
     Flux2SetTimestepsStep,
 )
 from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -76,6 +77,7 @@ Flux2BeforeDenoiseBlocks = InsertableDict(
     [
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
     ]
 )
@@ -139,6 +141,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
         ("text_input", Flux2TextInputStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
@@ -155,6 +158,7 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
         ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_guidance", Flux2PrepareGuidanceStep()),
         ("prepare_rope_inputs", Flux2RoPEInputsStep()),
         ("denoise", Flux2DenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
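These presets are ordered `InsertableDict`s, so after this change `prepare_guidance` runs between `set_timesteps` and `prepare_rope_inputs` in every preset. A sketch of composing a preset into a runnable block sequence (`from_blocks_dict` is part of the Modular Pipelines API rather than this diff, so treat the exact call as an assumption):

    from diffusers.modular_pipelines import SequentialPipelineBlocks

    # Sub-blocks execute in insertion order, passing PipelineState along.
    blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
    print(blocks)  # lists the ordered sub-blocks, including prepare_guidance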
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
+import PIL.Image
+from typing import List
 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
     Flux2PrepareImageLatentsStep,
     Flux2PrepareLatentsStep,
     Flux2RoPEInputsStep,
+    Flux2KleinBaseRoPEInputsStep,
     Flux2SetTimestepsStep,
 )
 from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
@@ -31,6 +35,7 @@ from .encoders import (
 from .inputs import (
     Flux2ProcessImagesInputStep,
     Flux2TextInputStep,
+    Flux2KleinBaseTextInputStep,
 )

@@ -101,7 +106,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
         return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
         return (
             "Core denoise step that performs the denoising process for Flux2-Klein.\n"
-            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
             " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
             " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
             " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
@@ -110,14 +115,24 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
             " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
         )

+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]
+

 Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
     [
-        ("input", Flux2TextInputStep()),
+        ("input", Flux2KleinBaseTextInputStep()),
         ("prepare_latents", Flux2PrepareLatentsStep()),
         ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
         ("set_timesteps", Flux2SetTimestepsStep()),
-        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
         ("denoise", Flux2KleinBaseDenoiseStep()),
         ("after_denoise", Flux2UnpackLatentsStep()),
     ]
@@ -134,14 +149,23 @@ class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
         return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
         return (
             "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
-            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
+            " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
             " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
             " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
-            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
+            " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
             " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
             " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
         )
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]


 ###
@@ -165,6 +189,16 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
             + " - for text-to-image generation, all you need to provide is `prompt`.\n"
         )

+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="images",
+                type_hint=List[PIL.Image.Image],
+                description="The images from the decoding step.",
+            )
+        ]
+

 class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
     model_name = "flux2-klein"
@@ -183,3 +217,13 @@ class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
             + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
             + " - for text-to-image generation, all you need to provide is `prompt`.\n"
         )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="images",
+                type_hint=List[PIL.Image.Image],
+                description="The images from the decoding step.",
+            )
+        ]
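For orientation, a hypothetical end-to-end sketch of running the auto blocks; `init_pipeline` and the `output="images"` convention come from the Modular Pipelines API, and the repo id is a placeholder, not something this diff establishes:

    # Hypothetical usage sketch; the repo id below is a placeholder.
    blocks = Flux2KleinAutoBlocks()
    pipe = blocks.init_pipeline("<modular-repo-id>")  # loads model components
    images = pipe(prompt="a photo of a cat", output="images")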