diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 6d25cde071..a69a1a140a 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -278,16 +278,62 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-I2V-14B-720p":
+        config = {
+            "model_id": "Wan-AI/Wan2.2-I2V-A14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+            },
+        }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-T2V-A14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.2-T2V-A14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+            },
+        }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+        return config, RENAME_DICT, SPECIAL_KEYS_REMAP
     return config, RENAME_DICT, SPECIAL_KEYS_REMAP
 
 
-def convert_transformer(model_type: str):
+def convert_transformer(model_type: str, stage: str = None):
     config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
     diffusers_config = config["diffusers_config"]
     model_id = config["model_id"]
     model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
+    if stage is not None:
+        model_dir = model_dir / stage
+
     original_state_dict = load_sharded_safetensors(model_dir)
 
     with init_empty_weights():
@@ -533,7 +579,13 @@ DTYPE_MAPPING = {
 
 if __name__ == "__main__":
     args = get_args()
-    transformer = convert_transformer(args.model_type)
+    if "Wan2.2" in args.model_type:
+        transformer = convert_transformer(args.model_type, stage="high_noise_model")
+        transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
+    else:
+        transformer = convert_transformer(args.model_type)
+        transformer_2 = None
+
     vae = convert_vae()
     text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
@@ -547,7 +599,17 @@ if __name__ == "__main__":
     dtype = DTYPE_MAPPING[args.dtype]
     transformer.to(dtype)
 
-    if "I2V" in args.model_type or "FLF2V" in args.model_type:
+    if "Wan2.2" in args.model_type and "I2V" in args.model_type:
+        pipe = WanImageToVideoPipeline(
+            transformer=transformer,
+            transformer_2=transformer_2,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            boundary_ratio=0.9,
+        )
+    elif "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
         )
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index c71138a97d..c2299a2e46 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -149,20 +149,32 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided,
+            only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage
+            denoising. The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When
+            provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2", "image_encoder", "image_processor"]
 
     def __init__(
        self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
-        image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        image_processor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModel = None,
+        transformer_2: WanTransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()
 
@@ -174,7 +186,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             transformer=transformer,
             scheduler=scheduler,
             image_processor=image_processor,
+            transformer_2=transformer_2,
         )
+        self.register_to_config(boundary_ratio=boundary_ratio)
 
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
@@ -325,6 +339,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         negative_prompt_embeds=None,
         image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if image is not None and image_embeds is not None:
             raise ValueError(
@@ -368,6 +383,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+        if self.config.boundary_ratio is not None and image_embeds is not None:
+            raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not None.")
+
     def prepare_latents(
         self,
         image: PipelineImageInput,
@@ -483,6 +504,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -527,6 +549,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None,
+                uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +614,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             negative_prompt_embeds,
             image_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +624,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -631,13 +662,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if image_embeds is None:
-            if last_image is None:
-                image_embeds = self.encode_image(image, device)
-            else:
-                image_embeds = self.encode_image([image, last_image], device)
-            image_embeds = image_embeds.repeat(batch_size, 1, 1)
-            image_embeds = image_embeds.to(transformer_dtype)
+
+        if self.config.boundary_ratio is None:
+            if image_embeds is None:
+                if last_image is None:
+                    image_embeds = self.encode_image(image, device)
+                else:
+                    image_embeds = self.encode_image([image, last_image], device)
+                image_embeds = image_embeds.repeat(batch_size, 1, 1)
+                image_embeds = image_embeds.to(transformer_dtype)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -668,16 +701,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
+                noise_pred = current_model(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
@@ -687,7 +737,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                    noise_uncond = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=negative_prompt_embeds,
@@ -695,7 +745,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
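
For context, a minimal usage sketch of the two-expert pipeline this patch assembles, assuming the conversion script has already saved a Wan2.2 I2V pipeline to a local directory. The directory name, input image, prompt, and guidance values below are illustrative placeholders, not values taken from this diff.

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder path: wherever the conversion script wrote the assembled pipeline.
pipe = WanImageToVideoPipeline.from_pretrained("./wan2.2-i2v-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("first_frame.png")
output = pipe(
    image=image,
    prompt="a red panda walking through bamboo, cinematic lighting",
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=4.0,    # used by `transformer` while t >= boundary_ratio * num_train_timesteps
    guidance_scale_2=3.0,  # used by `transformer_2` for the remaining low-noise steps
).frames[0]
export_to_video(output, "output.mp4", fps=16)

If `guidance_scale_2` is omitted, `__call__` falls back to `guidance_scale` for the low-noise stage, and with `boundary_ratio=None` the pipeline keeps the single-transformer Wan2.1 behavior.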