mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[core] add modular support for Flux I2I (#12086)
* start * encoder. * up * up * up * up * up * up
This commit is contained in:
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -103,6 +104,62 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Adapted from the original implementation.
|
||||
def prepare_latents_img2img(
|
||||
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
latent_channels = vae.config.latent_channels
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != latent_channels:
|
||||
image_latents = _encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -125,6 +182,55 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# Cannot use "# Copied from" because it introduces weird indentation errors.
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
def _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
vae_scale_factor,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
sigmas,
|
||||
device,
|
||||
):
|
||||
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@@ -234,18 +340,20 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
)
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -264,38 +372,131 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
|
||||
latents = block_state.latents
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
|
||||
)
|
||||
if components.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("strength", default=0.6),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam(
|
||||
"latent_timestep",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image generation",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
|
||||
def get_timesteps(scheduler, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
scheduler, num_inference_steps, block_state.strength, block_state.device
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@@ -305,7 +506,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
return "Prepare latents step that prepares the latents for the text-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
@@ -402,10 +603,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
@@ -418,3 +619,95 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the latents for the image-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# TODO: implement `check_inputs`
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
if block_state.latents is None:
|
||||
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
|
||||
components.vae,
|
||||
components.scheduler,
|
||||
block_state.image_latents,
|
||||
block_state.latent_timestep,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -226,5 +226,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `FluxLoopDenoiser`\n"
|
||||
" - `FluxLoopAfterDenoiser`\n"
|
||||
"This block supports text2image tasks."
|
||||
"This block supports both text2image and img2img tasks."
|
||||
)
|
||||
|
||||
@@ -19,7 +19,10 @@ import regex as re
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -50,6 +53,110 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Encoder step that encode the input image into a latent representation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("image", required=True), InputParam("height"), InputParam("width")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
"preprocess_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components.vae, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxTextEncoderStep(PipelineBlock):
|
||||
model_name = "flux"
|
||||
|
||||
@@ -297,7 +404,7 @@ class FluxTextEncoderStep(PipelineBlock):
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
device=block_state.device,
|
||||
num_images_per_prompt=1, # hardcoded for now.
|
||||
num_images_per_prompt=1, # TODO: hardcoded for now.
|
||||
lora_scale=block_state.text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,16 +15,38 @@
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import FluxTextEncoderStep
|
||||
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# before_denoise: text2vid
|
||||
# vae encoder (run before before_denoise)
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img, img2img
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxInputStep,
|
||||
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2vid,)
|
||||
# before_denoise: img2img
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxBeforeDenoiseStep]
|
||||
block_names = ["text2image"]
|
||||
block_trigger_inputs = [None]
|
||||
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
|
||||
block_names = ["text2image", "img2img"]
|
||||
block_trigger_inputs = [None, "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image tasks."
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# text2image
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
|
||||
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
|
||||
block_classes = [
|
||||
FluxTextEncoderStep,
|
||||
FluxAutoVaeEncoderStep,
|
||||
FluxAutoBeforeDenoiseStep,
|
||||
FluxAutoDenoiseStep,
|
||||
FluxAutoDecodeStep,
|
||||
]
|
||||
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
|
||||
)
|
||||
|
||||
|
||||
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
# Setting it after preparation of latents because we rely on `latents`
|
||||
# to calculate `img_seq_len` for `shift`.
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxVaeEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxAutoVaeEncoderStep),
|
||||
("before_denoise", FluxAutoBeforeDenoiseStep),
|
||||
("denoise", FluxAutoDenoiseStep),
|
||||
("decode", FluxAutoDecodeStep),
|
||||
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user