mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
support wan 2.2 i2v
This commit is contained in:
@@ -278,16 +278,62 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-T2V-A14B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-T2V-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_transformer(model_type: str):
|
||||
def convert_transformer(model_type: str, stage: str=None):
|
||||
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
model_id = config["model_id"]
|
||||
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
|
||||
|
||||
if stage is not None:
|
||||
model_dir = model_dir / stage
|
||||
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
@@ -533,7 +579,13 @@ DTYPE_MAPPING = {
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = convert_transformer(args.model_type)
|
||||
if "Wan2.2" in args.model_type:
|
||||
transformer = convert_transformer(args.model_type, stage="high_noise_model")
|
||||
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
|
||||
else:
|
||||
transformer = convert_transformer(args.model_type)
|
||||
transformer_2 = None
|
||||
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
@@ -547,7 +599,17 @@ if __name__ == "__main__":
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
transformer.to(dtype)
|
||||
|
||||
if "I2V" in args.model_type or "FLF2V" in args.model_type:
|
||||
if "Wan2.2" and "I2V" in args.model_type:
|
||||
pipe = WanImageToVideoPipeline(
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
boundary_ratio=0.9,
|
||||
)
|
||||
elif "I2V" in args.model_type or "FLF2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
@@ -149,20 +149,32 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage.
|
||||
In two-stage denoising, `transformer` handles high-noise stages
|
||||
and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`.
|
||||
When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep.
|
||||
If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
image_processor: CLIPImageProcessor=None,
|
||||
image_encoder: CLIPVisionModel=None,
|
||||
transformer_2: WanTransformer3DModel=None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -174,7 +186,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
@@ -325,6 +339,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
negative_prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if image is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -368,6 +383,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if self.config.boundary_ratio is not None and image_embeds is not None:
|
||||
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
@@ -483,6 +504,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -527,6 +549,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None,
|
||||
uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -589,6 +614,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
negative_prompt_embeds,
|
||||
image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
@@ -598,7 +624,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
@@ -631,13 +662,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
if image_embeds is None:
|
||||
if last_image is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
else:
|
||||
image_embeds = self.encode_image([image, last_image], device)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
if self.config.boundary_ratio is None:
|
||||
if image_embeds is None:
|
||||
if last_image is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
else:
|
||||
image_embeds = self.encode_image([image, last_image], device)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -668,16 +701,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
|
||||
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
noise_pred = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
@@ -687,7 +737,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
@@ -695,7 +745,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
Reference in New Issue
Block a user