1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

New HunyuanVideo-I2V (#11066)

* update

* update

* update

* add tests

* update docs

* raise value error

* warning for true cfg and guidance scale

* fix test
This commit is contained in:
Aryan
2025-03-24 21:18:40 +05:30
committed by GitHub
parent 5dbe4f5de6
commit 8907a70a36
6 changed files with 562 additions and 44 deletions

View File

@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
## Quantization

View File

@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": None,
},
"HYVideo-T/2-I2V": {
"HYVideo-T/2-I2V-33ch": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "latent_concat",
},
"HYVideo-T/2-I2V-16ch": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "token_replace",
},
}

View File

@@ -27,13 +27,15 @@ from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
PixArtAlphaTextProjection,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -173,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
return gate_msa, gate_mlp
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
6, dim=1
)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
)
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return hidden_states, gate_msa, tr_gate_msa
class HunyuanVideoConditionEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
pooled_projection_dim: int,
guidance_embeds: bool,
image_condition_type: Optional[str] = None,
):
super().__init__()
self.image_condition_type = image_condition_type
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
token_replace_emb = None
if self.image_condition_type == "token_replace":
token_replace_timestep = torch.zeros_like(timestep)
token_replace_proj = self.time_proj(token_replace_timestep)
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
token_replace_emb = token_replace_emb + pooled_projections
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
conditioning = conditioning + guidance_emb
return conditioning, token_replace_emb
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
@@ -390,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -468,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -503,6 +644,181 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_output = self.proj_out(hidden_states)
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
# 3. Modulation and residual connection
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
image_condition_type (`str`, *optional*, defaults to `None`):
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
tokens in the latent stream and apply conditioning.
"""
_supports_gradient_checkpointing = True
@@ -570,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
supported_image_condition_types = ["latent_concat", "token_replace"]
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
raise ValueError(
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
)
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
@@ -582,33 +909,52 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
if image_condition_type == "token_replace":
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
if image_condition_type == "token_replace":
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
else:
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -707,15 +1053,13 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -746,6 +1090,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
for block in self.single_transformer_blocks:
@@ -756,17 +1102,31 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
# 5. Output projection

View File

@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import load_image, export_to_video
>>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
... )
>>> output = pipe(image=image, prompt=prompt).frames[0]
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V
>>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
>>> export_to_video(output, "output.mp4", fps=15)
```
"""
@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
image_embed_interleave: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
image,
@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device=device,
dtype=dtype,
max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
)
if pooled_prompt_embeds is None:
@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
prompt_template=None,
true_cfg_scale=1.0,
guidance_scale=1.0,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
)
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
logger.warning(
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
"as it may lead to higher memory usage, slower inference and potentially worse results."
)
def prepare_latents(
self,
image: torch.Tensor,
@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
image_condition_type: str = "latent_concat",
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
image = image.unsqueeze(2) # [B, C, 1, H, W]
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
for i in range(batch_size)
]
else:
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
t = torch.tensor([0.999]).to(device=device)
latents = latents * t + image_latents * (1 - t)
if image_condition_type == "token_replace":
image_latents = image_latents[:, :, :1]
return latents, image_latents
def enable_vae_slicing(self):
@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256,
image_embed_interleave: Optional[int] = None,
):
r"""
The call function to the pipeline for generation.
@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds,
callback_on_step_end_tensor_inputs,
prompt_template,
true_cfg_scale,
guidance_scale,
)
image_condition_type = self.transformer.config.image_condition_type
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
image_embed_interleave = (
image_embed_interleave
if image_embed_interleave is not None
else (
2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
)
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
# 3. Prepare latent variables
vae_dtype = self.vae.dtype
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
if image_condition_type == "latent_concat":
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
elif image_condition_type == "token_replace":
num_channels_latents = self.transformer.config.in_channels
latents, image_latents = self.prepare_latents(
image_tensor,
batch_size * num_videos_per_prompt,
@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device,
generator,
latents,
image_condition_type,
)
image_latents[:, :, 1:] = 0
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0
if image_condition_type == "latent_concat":
image_latents[:, :, 1:] = 0
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0
# 4. Encode input prompt
transformer_dtype = self.transformer.dtype
@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_attention_mask=prompt_attention_mask,
device=device,
max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
)
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 6. Prepare guidance condition
guidance = None
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
continue
self._current_timestep = t
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if image_condition_type == "latent_concat":
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
elif image_condition_type == "token_replace":
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if image_condition_type == "latent_concat":
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
elif image_condition_type == "token_replace":
latents = latents = self.scheduler.step(
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
)[0]
latents = torch.cat([image_latents, latents], dim=2)
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = video[:, :, 4:, :, :]
if image_condition_type == "latent_concat":
video = video[:, :, 4:, :, :]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents[:, :, 1:, :, :]
if image_condition_type == "latent_concat":
video = latents[:, :, 1:, :, :]
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()

View File

@@ -80,6 +80,7 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -144,6 +145,7 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -209,6 +211,75 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "latent_concat",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 2
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 2,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "token_replace",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

View File

@@ -83,6 +83,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
text_embed_dim=16,
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
image_condition_type="latent_concat",
)
torch.manual_seed(0)