diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index a5ad844776..1a6f2af59a 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -106,35 +106,38 @@ class WanVACETransformerBlock(nn.Module):
     ) -> torch.Tensor:
         if self.proj_in is not None:
             control_hidden_states = self.proj_in(control_hidden_states)
-            hidden_states = hidden_states + control_hidden_states
-        else:
-            hidden_states = control_hidden_states
+            control_hidden_states = control_hidden_states + hidden_states
 
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
             self.scale_shift_table + temb.float()
         ).chunk(6, dim=1)
 
         # 1. Self-attention
-        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+            control_hidden_states
+        )
         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
-        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
 
         # 2. Cross-attention
-        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
-        hidden_states = hidden_states + attn_output
+        control_hidden_states = control_hidden_states + attn_output
 
         # 3. Feed-forward
-        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
-            hidden_states
+        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            control_hidden_states
         )
         ff_output = self.ffn(norm_hidden_states)
-        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+            control_hidden_states
+        )
 
+        conditioning_states = None
         if self.proj_out is not None:
-            control_hidden_states = self.proj_out(hidden_states)
+            conditioning_states = self.proj_out(control_hidden_states)
 
-        return hidden_states, control_hidden_states
+        return conditioning_states, control_hidden_states
 
 
 class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
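Note on the block change above: previously the VACE block overwrote `hidden_states` with its running state and returned the unprojected stream as the control hint. After the fix, the block accumulates on `control_hidden_states` and returns the `proj_out` projection separately as `conditioning_states`. A toy-scale sketch of the intended dataflow (illustrative names and stand-in ops, not the actual diffusers modules):

```python
import torch
import torch.nn as nn

# Each VACE block advances its own control stream and emits a projected hint;
# the main transformer stream is only read here, never overwritten.
batch, seq, dim = 1, 16, 32
hidden = torch.randn(batch, seq, dim)   # main transformer stream
control = torch.randn(batch, seq, dim)  # VACE control stream

def vace_block(hidden, control, proj_in, proj_out):
    if proj_in is not None:  # only the first block folds the main stream in
        control = proj_in(control) + hidden
    control = control + torch.tanh(control)  # stand-in for attention/FFN updates
    hint = proj_out(control)  # "conditioning_states": what the main blocks consume
    return hint, control      # the unprojected stream feeds the next VACE block

proj_ins = [nn.Identity(), None]  # stand-ins for the real proj_in layers
proj_outs = [nn.Linear(dim, dim) for _ in range(2)]
hints = []
for proj_in, proj_out in zip(proj_ins, proj_outs):
    hint, control = vace_block(hidden, control, proj_in, proj_out)
    hints.append(hint)

# The main blocks later add each hint back, scaled by control_hidden_states_scale:
for hint in hints:
    hidden = hidden + 1.0 * hint
```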
@@ -309,11 +312,9 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         # 2. Patch embedding
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
-        print("hidden_states", hidden_states.shape)
 
         control_hidden_states = self.vace_patch_embedding(control_hidden_states)
         control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
-        print("control_hidden_states", control_hidden_states.shape)
         control_hidden_states_padding = control_hidden_states.new_zeros(
             batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
         )
@@ -333,12 +334,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # Prepare VACE hints
             control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
-                    block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]
 
             for i, block in enumerate(self.blocks):
@@ -351,12 +351,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         else:
             # Prepare VACE hints
             control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = block(
-                    vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = block(
+                    hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]
 
             for i, block in enumerate(self.blocks):
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 82e9c05bfe..ae93c57a18 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -23,7 +23,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
-from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
@@ -137,7 +137,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
+        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -421,6 +421,13 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                     f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
                 )
 
+        ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
+        if any(l != ref_images_lengths[0] for l in ref_images_lengths):
+            raise ValueError(
+                f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
+                "may be added in the future."
+            )
+
         reference_images_preprocessed = []
         for i, reference_images_batch in enumerate(reference_images):
             preprocessed_images = []
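The `prepare_video_latents` hunks below add the Wan VAE's per-channel latent normalization, `z_norm = (z - mean) / std`, which the encoded latents were previously missing. `latents_std` is precomputed as the reciprocal so the normalization itself is a subtract-and-multiply. A minimal self-contained illustration (toy shapes, not pipeline code):

```python
import torch

z_dim = 4
latents = torch.randn(1, z_dim, 2, 8, 8)  # [B, C, T, H, W], toy sizes
mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)  # broadcasts over B, T, H, W
inv_std = 1.0 / torch.rand(z_dim).add(0.5).view(1, z_dim, 1, 1, 1)

normalized = (latents.float() - mean) * inv_std  # same as (latents - mean) / std
# Round-trip back to the original latents to sanity-check the transform:
restored = normalized / inv_std + mean
assert torch.allclose(restored, latents.float(), atol=1e-5)
```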
@@ -449,7 +456,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         mask: torch.Tensor,
         reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        device: Optional[torch.device] = None,
     ) -> torch.Tensor:
+        device = device or self._execution_device
+
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
@@ -473,8 +483,17 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         vae_dtype = self.vae.dtype
         video = video.to(dtype=vae_dtype)
 
+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+
         if mask is None:
-            latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
+            latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax")
+            latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
             mask = mask.to(dtype=vae_dtype)
             mask = torch.where(mask > 0.5, 1.0, 0.0)
@@ -482,6 +500,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
+            inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
+            reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
             latents = torch.cat([inactive, reactive], dim=1)
 
         latent_list = []
@@ -491,6 +511,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 reference_image = reference_image.to(dtype=vae_dtype)
                 reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
                 reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
                 reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
                 latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
                 latent_list.append(latent)
@@ -790,7 +811,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             device,
         )
 
-        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
+        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
         mask = self.prepare_masks(mask, reference_images, generator)
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
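Note: the originally added normalization line called `.float()` on the tuple returned by `.unbind(0)`, which would raise at runtime; the `mask is None` branch above keeps `latents` as a tensor instead (downstream iteration over the batch dimension behaves the same). Separately, the warning added in the final hunk guards against a silent frame mismatch. A hedged, illustrative way to trigger it (assuming the Wan VAE's 4x temporal stride): passing a 49-frame `video` while requesting `num_frames=81`.

```python
# Hedged arithmetic (4x temporal stride assumed for the Wan VAE):
to_latent_frames = lambda n: (n - 1) // 4 + 1
print(to_latent_frames(49), to_latent_frames(81))  # 13 vs 21 -> shapes mismatch, warning fires
```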
@@ -808,6 +829,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
         )
 
+        if conditioning_latents.shape[2] != latents.shape[2]:
+            logger.warning(
+                "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
+            )
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
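For reference, a hedged end-to-end smoke test of the updated pipeline in its text-to-video configuration (no `video`/`mask` conditioning); the checkpoint id is illustrative, substitute the VACE checkpoint under test:

```python
import torch
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Checkpoint id is an assumption, not pinned by this PR.
pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic style",
    negative_prompt="blurry, low quality",
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```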