mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[HunyuanVideo1.5] support step-distilled (#12802)
* support step-distilled * style
This commit is contained in:
@@ -69,6 +69,11 @@ TRANSFORMER_CONFIGS = {
|
||||
"target_size": 960,
|
||||
"task_type": "i2v",
|
||||
},
|
||||
"480p_i2v_step_distilled": {
|
||||
"target_size": 640,
|
||||
"task_type": "i2v",
|
||||
"use_meanflow": True,
|
||||
},
|
||||
}
|
||||
|
||||
SCHEDULER_CONFIGS = {
|
||||
@@ -93,6 +98,9 @@ SCHEDULER_CONFIGS = {
|
||||
"720p_i2v_distilled": {
|
||||
"shift": 7.0,
|
||||
},
|
||||
"480p_i2v_step_distilled": {
|
||||
"shift": 7.0,
|
||||
},
|
||||
}
|
||||
|
||||
GUIDANCE_CONFIGS = {
|
||||
@@ -117,6 +125,9 @@ GUIDANCE_CONFIGS = {
|
||||
"720p_i2v_distilled": {
|
||||
"guidance_scale": 1.0,
|
||||
},
|
||||
"480p_i2v_step_distilled": {
|
||||
"guidance_scale": 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +137,7 @@ def swap_scale_shift(weight):
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
|
||||
def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None):
|
||||
"""
|
||||
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
|
||||
"""
|
||||
@@ -142,6 +153,20 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")
|
||||
|
||||
if config.use_meanflow:
|
||||
converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
|
||||
"time_r_in.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
|
||||
"time_r_in.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
|
||||
"time_r_in.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
|
||||
"time_r_in.mlp.2.bias"
|
||||
)
|
||||
|
||||
# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
|
||||
@@ -627,7 +652,7 @@ def convert_transformer(args):
|
||||
config = TRANSFORMER_CONFIGS[args.transformer_type]
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideo15Transformer3DModel(**config)
|
||||
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
|
||||
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config)
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
return transformer
|
||||
|
||||
@@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
|
||||
The dimension of the output embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
def __init__(self, embedding_dim: int, use_meanflow: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_meanflow = use_meanflow
|
||||
self.time_proj_r = None
|
||||
self.timestep_embedder_r = None
|
||||
if use_meanflow:
|
||||
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
|
||||
|
||||
if timestep_r is not None:
|
||||
timesteps_proj_r = self.time_proj_r(timestep_r)
|
||||
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
|
||||
timesteps_emb = timesteps_emb + timesteps_emb_r
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
@@ -567,6 +580,7 @@ class HunyuanVideo15Transformer3DModel(
|
||||
# YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
|
||||
target_size: int = 640, # did not name sample_size since it is in pixel spaces
|
||||
task_type: str = "i2v",
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -582,7 +596,7 @@ class HunyuanVideo15Transformer3DModel(
|
||||
)
|
||||
self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
|
||||
|
||||
self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim)
|
||||
self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow)
|
||||
|
||||
self.cond_type_embed = nn.Embedding(3, inner_dim)
|
||||
|
||||
@@ -612,6 +626,7 @@ class HunyuanVideo15Transformer3DModel(
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
timestep_r: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask_2: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
@@ -643,7 +658,7 @@ class HunyuanVideo15Transformer3DModel(
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
temb = self.time_embed(timestep)
|
||||
temb = self.time_embed(timestep, timestep_r=timestep_r)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
|
||||
@@ -852,6 +852,15 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
||||
|
||||
if self.transformer.config.use_meanflow:
|
||||
if i == len(timesteps) - 1:
|
||||
timestep_r = torch.tensor([0.0], device=device)
|
||||
else:
|
||||
timestep_r = timesteps[i + 1]
|
||||
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
|
||||
else:
|
||||
timestep_r = None
|
||||
|
||||
# Step 1: Collect model inputs needed for the guidance method
|
||||
# conditional inputs should always be first element in the tuple
|
||||
guider_inputs = {
|
||||
@@ -893,6 +902,7 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
|
||||
hidden_states=latent_model_input,
|
||||
image_embeds=image_embeds,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep_r,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user