diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index 34c05424d5..a5ad844776 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -175,7 +175,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
     """

     _supports_gradient_checkpointing = True
-    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@@ -273,9 +273,6 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if control_hidden_states is None:
-            raise ValueError("Control hidden states must be provided for VACE models.")
-
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -299,6 +296,12 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO

         if control_hidden_states_scale is None:
             control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
+        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
+        if len(control_hidden_states_scale) != len(self.config.vace_layers):
+            raise ValueError(
+                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
+                f"equal to {len(self.config.vace_layers)}."
+            )

         # 1. Rotary position embedding
         rotary_emb = self.rope(hidden_states)
@@ -329,11 +332,11 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
                     block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
@@ -341,24 +344,24 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                 )
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
         else:
             # Prepare VACE hints
             control_hidden_states_list = []
             vace_hidden_states = hidden_states
-            for block in self.vace_blocks:
+            for i, block in enumerate(self.vace_blocks):
                 vace_hidden_states, control_hidden_states = block(
                     vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append(control_hidden_states)
+                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                 if i in self.config.vace_layers:
-                    control_hint = control_hidden_states_list.pop()
-                    hidden_states = hidden_states + control_hint * control_hidden_states_scale[i]
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale

         # 6. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
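The transformer-side change above pairs every VACE hint with its own scale instead of indexing a single shared value by block position. The sketch below reproduces that unbind/validate/pair pattern in isolation; the layer indices, tensor shapes, block count, and scale values are illustrative assumptions, not values taken from the model config.

```python
# Standalone sketch of the unbind/validate/pair pattern used in WanVACETransformer3DModel.forward above.
# All shapes, layer indices, and scale values here are illustrative assumptions.
import torch

vace_layers = (0, 5, 10, 15)  # hypothetical indices of blocks that receive a VACE hint
hints = [torch.randn(1, 16, 64) for _ in vace_layers]  # stand-ins for per-block control hidden states

control_hidden_states_scale = torch.tensor([1.0, 0.9, 0.8, 0.7])
scales = torch.unbind(control_hidden_states_scale)
if len(scales) != len(vace_layers):
    raise ValueError(f"Length of `control_hidden_states_scale` {len(scales)} should be equal to {len(vace_layers)}.")

# Each hint is stored together with its own scale, then added to the output of the matching block.
hint_stack = [(hint, scale) for hint, scale in zip(hints, scales)][::-1]
hidden_states = torch.randn(1, 16, 64)
for i in range(20):  # 20 stands in for the total number of transformer blocks
    if i in vace_layers:
        control_hint, scale = hint_stack.pop()
        hidden_states = hidden_states + control_hint * scale
```

Carrying a scalar per VACE block rather than one global value is what allows callers to weight the control hints differently at different depths of the network.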
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 235e247cba..82e9c05bfe 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -292,7 +292,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         mask=None,
         reference_images=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size
+        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")

@@ -368,39 +368,78 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            video = self.video_processor.preprocess_video(video, None, None)  # Use the height/width of video
-            image_size = tuple(video.shape[-2:])
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+            if video_height * video_width > height * width:
+                scale = min(width / video_width, height / video_height)
+                video_height, video_width = int(video_height * scale), int(video_width * scale)
+
+            if video_height % base != 0 or video_width % base != 0:
+                logger.warning(
+                    f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
+                )
+                video_height = (video_height // base) * base
+                video_width = (video_width // base) * base
+
+            assert video_height * video_width <= height * width
+
+            video = self.video_processor.preprocess_video(video, video_height, video_width)
+            image_size = (video_height, video_width)  # Use the height/width of video (with possible rescaling)
         else:
-            video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=dtype, device=device)
+            video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
             image_size = (height, width)  # Use the height/width provider by user

         if mask is not None:
-            mask = self.video_processor.preprocess_video(mask, height, width)
+            mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
         else:
-            mask = torch.ones_like(video, dtype=dtype, device=device)
+            mask = torch.ones_like(video)

         video = video.to(dtype=dtype, device=device)
         mask = mask.to(dtype=dtype, device=device)

-        reference_images_preprocessed = []
-        if reference_images is not None:
-            if not isinstance(reference_images, list):
-                reference_images = [reference_images]
-
-            for i, image in enumerate(reference_images):
-                image = self.video_processor.preprocess(image, None, None)  # Use the height/width of image
+        # Make a list of list of images where the outer list corresponds to video batch size and the inner list
+        # corresponds to list of conditioning images per video
+        if reference_images is None or isinstance(reference_images, PIL.Image.Image):
+            reference_images = [[reference_images] for _ in range(video.shape[0])]
+        elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
+            reference_images = [reference_images]
+        elif (
+            isinstance(reference_images, (list, tuple))
+            and isinstance(next(iter(reference_images)), list)
+            and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
+        ):
+            reference_images = reference_images
+        else:
+            raise ValueError(
+                "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+                f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+            )
+
+        if video.shape[0] != len(reference_images):
+            raise ValueError(
+                f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+            )
+
+        reference_images_preprocessed = []
+        for i, reference_images_batch in enumerate(reference_images):
+            preprocessed_images = []
+            for j, image in enumerate(reference_images_batch):
+                if image is None:
+                    continue
+                image = self.video_processor.preprocess(image, None, None)
                 img_height, img_width = image.shape[-2:]
                 scale = min(image_size[0] / img_height, image_size[1] / img_width)
                 new_height, new_width = int(img_height * scale), int(img_width * scale)
                 resized_image = torch.nn.functional.interpolate(
-                    image.unsqueeze(1), size=(new_height, new_width), mode="bilinear", align_corners=False
-                ).squeeze(1)
-
+                    image, size=(new_height, new_width), mode="bilinear", align_corners=False
+                ).squeeze(0)  # [C, H, W]
                 top = (image_size[0] - new_height) // 2
                 left = (image_size[1] - new_width) // 2
-                canvas = torch.ones(batch_size, 1, 3, *image_size, device=device, dtype=dtype)
-                canvas[:, :, :, top : top + new_height, left : left + new_width] = resized_image
-                reference_images_preprocessed.append(canvas)
+                canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
+                canvas[:, top : top + new_height, left : left + new_width] = resized_image
+                preprocessed_images.append(canvas)
+            reference_images_preprocessed.append(preprocessed_images)

         return video, mask, reference_images_preprocessed

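The `preprocess_conditions` hunk above accepts a single `PIL.Image.Image`, a flat list of images, or a list of lists, and normalizes all of them to one inner list of reference images per video in the batch. Below is a minimal standalone sketch of that normalization; `normalize_reference_images` is a hypothetical helper written for illustration, not part of the diffusers API.

```python
# Minimal sketch of the reference-image normalization performed in preprocess_conditions above.
# `normalize_reference_images` is a hypothetical name used only for this illustration.
from typing import List, Optional, Union

import PIL.Image


def normalize_reference_images(
    reference_images: Optional[Union[PIL.Image.Image, list, tuple]], batch_size: int
) -> List[list]:
    if reference_images is None or isinstance(reference_images, PIL.Image.Image):
        # One (possibly None) reference per video in the batch.
        return [[reference_images] for _ in range(batch_size)]
    if isinstance(reference_images, (list, tuple)) and isinstance(reference_images[0], PIL.Image.Image):
        # A flat list is treated as the references for a single video.
        return [list(reference_images)]
    if isinstance(reference_images, (list, tuple)) and isinstance(reference_images[0], (list, tuple)):
        # Already one inner list of references per video.
        return [list(refs) for refs in reference_images]
    raise ValueError(f"Unsupported `reference_images` type: {type(reference_images)}")


print(normalize_reference_images(None, batch_size=2))  # [[None], [None]]
```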
@@ -408,7 +447,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         self,
         video: torch.Tensor,
         mask: torch.Tensor,
-        reference_images: Optional[List[torch.Tensor]] = None,
+        reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
     ) -> torch.Tensor:
         if isinstance(generator, list):
@@ -416,7 +455,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")

         if reference_images is None:
-            # For each batch of video, we set no reference image (as one or more can be passed by user)
+            # For each batch of video, we set no reference image
+            # (as one or more can be passed by user)
             reference_images = [[None] for _ in range(video.shape[0])]
         else:
             if video.shape[0] != len(reference_images):
@@ -437,22 +477,24 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
         else:
             mask = mask.to(dtype=vae_dtype)
-            mask = [torch.where(m > 0.5, 1.0, 0.0) for m in mask]
-            inactive = [v * (1 - m) for v, m in zip(video, mask)]
-            reactive = [v * m for v, m in zip(video, mask)]
+            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            inactive = video * (1 - mask)
+            reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
-            latents = [torch.cat([i, r], dim=0) for i, r in zip(inactive, reactive)]
+            latents = torch.cat([inactive, reactive], dim=1)

         latent_list = []
-        for latent, ref_images in zip(latents, reference_images):
-            if ref_images is not None:
-                ref_images = ref_images.to(dtype=vae_dtype)
-                ref_latents = retrieve_latents(self.vae.encode(ref_images), generator, sample_mode="argmax")
-                ref_latents = [torch.cat([r, torch.zeros_like(r)], dim=0) for r in ref_latents]
-                latent = torch.cat([*ref_latents, latent], dim=1)
+        for latent, reference_images_batch in zip(latents, reference_images):
+            for reference_image in reference_images_batch:
+                assert reference_image.ndim == 3
+                reference_image = reference_image.to(dtype=vae_dtype)
+                reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
+                reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
+                latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
-        return latent_list
+        return torch.stack(latent_list)

     def prepare_masks(
         self,
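With the change above, `prepare_video_latents` returns a single stacked tensor instead of a list: inactive and reactive latents are concatenated along the channel dimension, and each reference image contributes one extra latent frame that is prepended along the frame dimension. The shape walk-through below uses assumed Wan 2.1 values (16 latent channels, temporal factor 4, spatial factor 8) for an 81-frame 480x832 video with a single reference image; none of these numbers are read from a checkpoint.

```python
# Illustrative shape check for the latent layout built above; all sizes are assumed Wan 2.1 values.
import torch

latent_channels = 16
latent_frames = (81 - 1) // 4 + 1             # 21 latent frames for an 81-frame video
latent_height, latent_width = 480 // 8, 832 // 8

inactive = torch.randn(1, latent_channels, latent_frames, latent_height, latent_width)
reactive = torch.randn_like(inactive)
latents = torch.cat([inactive, reactive], dim=1)  # channels double to 32

reference_latent = torch.randn(1, latent_channels, 1, latent_height, latent_width)
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)

# Per batch entry: prepend the reference frame along the frame dimension, as in the loop above.
latent = torch.cat([reference_latent.squeeze(0), latents[0]], dim=1)
print(latent.shape)  # torch.Size([32, 22, 60, 104])
```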
@@ -479,25 +521,28 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )

+        transformer_patch_size = self.transformer.config.patch_size[1]
+
         mask_list = []
-        transformer_patch_size = self.transformer.config.patch_size
-        for mask_, ref_images in zip(mask, reference_images):
-            num_frames, num_channels, height, width = mask_.shape
+        for mask_, reference_images_batch in zip(mask, reference_images):
+            num_channels, num_frames, height, width = mask_.shape
             new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
             new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
             new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
-            mask_ = mask_[:, 0, :, :]
-            mask_ = mask_.view(num_frames, height, self.vae_scale_factor_spatial, width, self.vae_scale_factor_spatial)
-            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(2, 4).flatten(0, 1)
+            mask_ = mask_[0, :, :, :]
+            mask_ = mask_.view(
+                num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
+            )
+            mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1)  # [8x8, num_frames, new_height, new_width]
             mask_ = torch.nn.functional.interpolate(
                 mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
             ).squeeze(0)
-            if ref_images is not None:
-                num_ref_images = ref_images.size(0)
-                mask_padding = torch.zeros_like(mask[:num_ref_images, :, :, :])
+            num_ref_images = len(reference_images_batch)
+            if num_ref_images > 0:
+                mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
                 mask_ = torch.cat([mask_, mask_padding], dim=1)
             mask_list.append(mask_)
-        return mask_list
+        return torch.stack(mask_list)

     def prepare_latents(
         self,
@@ -746,12 +791,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         )

         conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
-        conditioning_latents = [c.to(transformer_dtype) for c in conditioning_latents]
-
         mask = self.prepare_masks(mask, reference_images, generator)
-        mask = [m.to(transformer_dtype) for m in mask]
-
-        conditioning_latents = [torch.cat([c, m], dim=1) for c, m in zip(conditioning_latents, mask)]
+        conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
+        conditioning_latents = conditioning_latents.to(transformer_dtype)

         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(