mirror of
https://github.com/huggingface/diffusers.git
update
@@ -17,11 +17,14 @@ from typing import List, Optional, Union
import regex as re
import torch
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan
from ...utils import is_ftfy_available, logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
    return text


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

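# --- Illustrative sketch (not part of the diffusers source): how `retrieve_latents` resolves the
# different encoder-output flavors. `_ToyDist` and `_ToyEncoderOutput` are hypothetical stand-ins
# for a VAE posterior; only the function above and the imports at the top of this file are assumed.
class _ToyDist:
    def __init__(self, mean: torch.Tensor):
        self.mean = mean

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self) -> torch.Tensor:
        return self.mean


class _ToyEncoderOutput:
    def __init__(self, dist: _ToyDist):
        self.latent_dist = dist


_toy = _ToyEncoderOutput(_ToyDist(torch.zeros(1, 16, 1, 8, 8)))
assert retrieve_latents(_toy, sample_mode="argmax").equal(torch.zeros(1, 16, 1, 8, 8))  # deterministic mode
_ = retrieve_latents(_toy, sample_mode="sample")  # stochastic: samples from the posterior

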
class WanTextEncoderStep(PipelineBlock):
    model_name = "wan"

@@ -240,3 +257,233 @@ class WanTextEncoderStep(PipelineBlock):
        # Add outputs
        self.set_block_state(state, block_state)
        return components, state


class WanImageEncodeStep(PipelineBlock):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "Image Encoder step to compute image embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("image_encoder", CLIPVisionModel),
            ComponentSpec("image_processor", CLIPImageProcessor),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                required=True,
                description="The input image to condition the generation on for first-frame conditioned video generation.",
            ),
            InputParam(
                "last_image",
                required=False,
                description="The last image to condition the generation on for last-frame conditioned video generation.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "encoder_hidden_states_image",
                type_hint=torch.Tensor,
                description="image embeddings used to guide the video generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if not isinstance(block_state.image, PipelineImageInput):
            raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
        if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
            raise ValueError(
                f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
            )

    @staticmethod
    def encode_image(
        components,
        image: PipelineImageInput,
        device: torch.device,
    ):
        image = components.image_processor(images=image, return_tensors="pt").to(device)
        image_embeds = components.image_encoder(**image, output_hidden_states=True)
        return image_embeds.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        # Encode input images
        image = block_state.image
        if block_state.last_image is not None:
            image = [block_state.image, block_state.last_image]

        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state


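# --- Illustrative sketch (not part of the diffusers source): what `encode_image` computes.
# "openai/clip-vit-large-patch14" is only an example checkpoint; the actual Wan image-to-video
# repositories ship their own CLIP vision tower and processor configuration.
from PIL import Image

import numpy as np
from transformers import CLIPImageProcessor, CLIPVisionModel

_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))  # dummy black frame
_inputs = _processor(images=_image, return_tensors="pt")
_outputs = _encoder(**_inputs, output_hidden_states=True)
# The step keeps the penultimate hidden state: one sequence of patch tokens per input image.
print(_outputs.hidden_states[-2].shape)  # e.g. torch.Size([1, 257, 1024]) for this checkpoint

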
class WanVaeEncoderStep(PipelineBlock):
    model_name = "wan"

    @property
    def description(self) -> str:
        return (
            "VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", required=True),
            InputParam("last_image", required=False),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam("num_channels_latents", type_hint=int),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latent_condition",
                type_hint=torch.Tensor,
                description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
            )
        ]

    def _encode_vae_image(
        self,
        components: WanModularPipeline,
        batch_size: int,
        height: int,
        width: int,
        num_frames: int,
        image: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
        last_image: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial

        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(device, dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(device, dtype)

        image = image.unsqueeze(2)
        if last_image is None:
            video_condition = torch.cat(
                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
            )
        else:
            last_image = last_image.unsqueeze(2)
            video_condition = torch.cat(
                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
                dim=2,
            )
        video_condition = video_condition.to(device=device, dtype=dtype)

        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        latent_condition = latent_condition.to(dtype)
        latent_condition = (latent_condition - latents_mean) * latents_std

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        if last_image is None:
            mask_lat_size[:, :, list(range(1, num_frames))] = 0
        else:
            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)
        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)

        return latent_condition

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.num_channels_latents = components.vae.config.z_dim
        block_state.batch_size = (
            block_state.batch_size if block_state.batch_size is not None else block_state.image.shape[0]
        )

        block_state.image = components.video_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width
        ).to(block_state.device, dtype=torch.float32)
        if block_state.last_image is not None:
            block_state.last_image = components.video_processor.preprocess(
                block_state.last_image, height=block_state.height, width=block_state.width
            ).to(block_state.device, dtype=torch.float32)

        block_state.latent_condition = self._encode_vae_image(
            components,
            batch_size=block_state.batch_size,
            height=block_state.height,
            width=block_state.width,
            num_frames=block_state.num_frames,
            image=block_state.image,
            device=block_state.device,
            dtype=block_state.dtype,
            last_image=block_state.last_image,
            generator=block_state.generator,
        )

        self.set_block_state(state, block_state)
        return components, state


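# --- Illustrative sketch (not part of the diffusers source): the mask bookkeeping in
# `_encode_vae_image` above, run with plain torch and illustrative numbers (81 frames,
# temporal VAE stride 4, spatial stride 8, a 480x832 video, first-frame conditioning).
import torch

batch_size, num_frames = 1, 81
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8
latent_height, latent_width = 480 // vae_scale_factor_spatial, 832 // vae_scale_factor_spatial  # 60, 104

mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0  # only the first frame is "known"

first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)  # (1, 1, 84, 60, 104)
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask = mask.transpose(1, 2)  # (1, 4, 21, 60, 104)

# 21 matches the VAE's temporal compression of 81 frames: (81 - 1) // 4 + 1.
# Concatenated with the 16-channel Wan latent_condition, the conditioning tensor has 20 channels.
print(mask.shape)  # torch.Size([1, 4, 21, 60, 104])
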
@@ -22,12 +22,27 @@ from .before_denoise import (
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep, WanVaeEncoderStep


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [WanVaeEncoderStep]
    block_names = ["img2vid"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "Vae encoder step that encodes the image inputs into their latent representations.\n"
            + "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
            + " - `WanVaeEncoderStep` (img2vid) is used when `image` (and optionally `last_image`) is provided.\n"
            + " - if `image` is not provided, this step will be skipped."
        )


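# --- Illustrative sketch (not part of the diffusers source): a minimal stand-in for the
# trigger-input dispatch that `AutoPipelineBlocks` performs. The real class is more involved;
# this only mirrors the behavior described in the docstring above.
def _dispatch(block_trigger_inputs, block_names, provided_inputs):
    for trigger, name in zip(block_trigger_inputs, block_names):
        if provided_inputs.get(trigger) is not None:
            return name  # run this sub-block
    return None  # no trigger present: the whole auto block is skipped


print(_dispatch(["image"], ["img2vid"], {"image": "<first frame>"}))  # -> "img2vid"
print(_dispatch(["image"], ["img2vid"], {"prompt": "a cat"}))  # -> None (step skipped)

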
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
@@ -97,6 +112,7 @@ class WanAutoDecodeStep(AutoPipelineBlocks):
class WanAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoVaeEncoderStep,
        WanAutoBeforeDenoiseStep,
        WanAutoDenoiseStep,
        WanAutoDecodeStep,
@@ -128,10 +144,23 @@ TEXT2VIDEO_BLOCKS = InsertableDict(
)


IMAGE2VIDEO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanInputStep),
        ("image_encoder", WanVaeEncoderStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", WanDenoiseStep),
        ("decode", WanDecodeStep),
    ]
)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("image_encoder", WanAutoVaeEncoderStep),
        ("before_denoise", WanAutoBeforeDenoiseStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanAutoDecodeStep),
    ]