diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index b2ecfd1aa6..914f48a450 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -17,11 +17,15 @@ from typing import List, Optional, Union
 
+import PIL.Image
 import regex as re
 import torch
-from transformers import AutoTokenizer, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
 from ...utils import is_ftfy_available, logging
+from ...video_processor import VideoProcessor
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class WanTextEncoderStep(PipelineBlock):
     model_name = "wan"
 
@@ -240,3 +257,240 @@ class WanTextEncoderStep(PipelineBlock):
         # Add outputs
         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanImageEncodeStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "Image encoder step that computes image embeddings to guide the video generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_encoder", CLIPVisionModel),
+            ComponentSpec("image_processor", CLIPImageProcessor),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return []
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "image",
+                required=True,
+                description="The input image to condition the generation on for first-frame conditioned video generation.",
+            ),
+            InputParam(
+                "last_image",
+                required=False,
+                description="The last image to condition the generation on for last-frame conditioned video generation.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "encoder_hidden_states_image",
+                type_hint=torch.Tensor,
+                description="image embeddings used to guide the video generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        # `PipelineImageInput` is a `typing.Union` alias, so it cannot be used with `isinstance`;
+        # check against the concrete runtime types instead.
+        if not isinstance(block_state.image, (torch.Tensor, PIL.Image.Image)):
+            raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(block_state.image)}.")
+        if block_state.last_image is not None and not isinstance(block_state.last_image, (torch.Tensor, PIL.Image.Image)):
+            raise ValueError(
+                f"`last_image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(block_state.last_image)}."
+            )
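+
+    # NOTE: Wan's image conditioning uses the CLIP vision model's penultimate hidden state
+    # (`hidden_states[-2]`) rather than the pooled/projected output.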
+    @staticmethod
+    def encode_image(
+        components,
+        image: PipelineImageInput,
+        device: torch.device,
+    ):
+        image = components.image_processor(images=image, return_tensors="pt").to(device)
+        image_embeds = components.image_encoder(**image, output_hidden_states=True)
+        return image_embeds.hidden_states[-2]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.device = components._execution_device
+
+        # Encode input images
+        image = block_state.image
+        if block_state.last_image is not None:
+            image = [block_state.image, block_state.last_image]
+
+        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+
+        # Add outputs
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class WanVaeEncoderStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return (
+            "VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLWan),
+            ComponentSpec(
+                "video_processor",
+                VideoProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("last_image", required=False),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("num_frames", type_hint=int),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("batch_size", type_hint=int),
+            InputParam("num_channels_latents", type_hint=int),
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latent_condition",
+                type_hint=torch.Tensor,
+                description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
+            )
+        ]
+
+    def _encode_vae_image(
+        self,
+        components: WanModularPipeline,
+        batch_size: int,
+        height: int,
+        width: int,
+        num_frames: int,
+        image: torch.Tensor,
+        device: torch.device,
+        dtype: torch.dtype,
+        last_image: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ):
+        latent_height = height // components.vae_scale_factor_spatial
+        latent_width = width // components.vae_scale_factor_spatial
+
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean).view(1, components.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(device, dtype)
+
+        image = image.unsqueeze(2)
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
+        video_condition = video_condition.to(device=device, dtype=dtype)
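+
+        # Deterministic VAE encode: `sample_mode="argmax"` takes the mode of the latent
+        # distribution; the result is then normalized with the VAE's per-channel latent stats.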
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+        latent_condition = latent_condition.to(dtype)
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
+        # Build the temporal conditioning mask: 1 for reference frames, 0 for frames to generate.
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(latent_condition.device)
+        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
+
+        return latent_condition
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+        block_state.num_channels_latents = components.vae.config.z_dim
+
+        block_state.image = components.video_processor.preprocess(
+            block_state.image, height=block_state.height, width=block_state.width
+        ).to(block_state.device, dtype=torch.float32)
+        if block_state.last_image is not None:
+            block_state.last_image = components.video_processor.preprocess(
+                block_state.last_image, height=block_state.height, width=block_state.width
+            ).to(block_state.device, dtype=torch.float32)
+        block_state.batch_size = (
+            block_state.batch_size if block_state.batch_size is not None else block_state.image.shape[0]
+        )
+
+        block_state.latent_condition = self._encode_vae_image(
+            components,
+            batch_size=block_state.batch_size,
+            height=block_state.height,
+            width=block_state.width,
+            num_frames=block_state.num_frames,
+            image=block_state.image,
+            device=block_state.device,
+            dtype=block_state.dtype,
+            last_image=block_state.last_image,
+            generator=block_state.generator,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
index 5f4c1a9835..21d6eb5c9e 100644
--- a/src/diffusers/modular_pipelines/wan/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -22,12 +22,27 @@ from .before_denoise import (
 )
 from .decoders import WanDecodeStep
 from .denoise import WanDenoiseStep
-from .encoders import WanTextEncoderStep
+from .encoders import WanTextEncoderStep, WanVaeEncoderStep
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class WanAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [WanVaeEncoderStep]
+    block_names = ["img2vid"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+            + " - `WanVaeEncoderStep` (img2vid) is used when `image` (and optionally `last_image`) is provided.\n"
+            + " - if `image` is not provided, this step will be skipped."
+        )
+
+
 # before_denoise: text2vid
 class WanBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
@@ -97,6 +112,7 @@ class WanAutoDecodeStep(AutoPipelineBlocks):
 class WanAutoBlocks(SequentialPipelineBlocks):
     block_classes = [
         WanTextEncoderStep,
+        WanAutoVaeEncoderStep,
         WanAutoBeforeDenoiseStep,
         WanAutoDenoiseStep,
         WanAutoDecodeStep,
@@ -128,10 +144,24 @@ TEXT2VIDEO_BLOCKS = InsertableDict(
 )
 
 
+IMAGE2VIDEO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", WanTextEncoderStep),
+        ("input", WanInputStep),
+        ("image_encoder", WanVaeEncoderStep),
+        ("set_timesteps", WanSetTimestepsStep),
+        ("prepare_latents", WanPrepareLatentsStep),
+        ("denoise", WanDenoiseStep),
+        ("decode", WanDecodeStep),
+    ]
+)
+
+
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", WanTextEncoderStep),
-        ("before_denoise", WanAutoBeforeDenoiseStep),
+        ("image_encoder", WanAutoVaeEncoderStep),
+        ("before_denoise", WanAutoBeforeDenoiseStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanAutoDecodeStep),
    ]
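
Usage note (illustrative only, not part of the diff): a rough sketch of how the new `IMAGE2VIDEO_BLOCKS` preset could be driven end to end. The repo id, image URL, and generation parameters are placeholders, and the `from_blocks_dict` / `init_pipeline` / `load_default_components` / `output="videos"` calls reflect my reading of the modular-pipelines API rather than anything this diff guarantees.

```python
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.wan.modular_blocks import IMAGE2VIDEO_BLOCKS
from diffusers.utils import export_to_video, load_image

# Assemble the preset block dict added in this diff into a runnable pipeline.
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2VIDEO_BLOCKS)
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")  # placeholder repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# IMAGE2VIDEO_BLOCKS wires WanVaeEncoderStep in directly; with AUTO_BLOCKS the same
# `image` input would trigger WanAutoVaeEncoderStep via `block_trigger_inputs`.
first_frame = load_image("https://example.com/first_frame.png")  # placeholder URL
result = pipe(
    prompt="A cat walks on the grass, realistic style",
    image=first_frame,
    height=480,
    width=832,
    num_frames=81,
    output="videos",
)
export_to_video(result[0], "wan_i2v.mp4", fps=16)
```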