mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
New HunyuanVideo-I2V (#11066)
* update * update * update * add tests * update docs * raise value error * warning for true cfg and guidance scale * fix test
This commit is contained in:
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
|
||||
## Quantization
|
||||
|
||||
|
||||
@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": None,
|
||||
},
|
||||
"HYVideo-T/2-I2V": {
|
||||
"HYVideo-T/2-I2V-33ch": {
|
||||
"in_channels": 16 * 2 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "latent_concat",
|
||||
},
|
||||
"HYVideo-T/2-I2V-16ch": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"image_condition_type": "token_replace",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -27,13 +27,15 @@ from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
PixArtAlphaTextProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -173,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
|
||||
6, dim=1
|
||||
)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
token_replace_emb: torch.Tensor,
|
||||
first_frame_num_tokens: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
token_replace_emb = self.linear(self.silu(token_replace_emb))
|
||||
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
hidden_states_zero = (
|
||||
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
|
||||
)
|
||||
hidden_states_orig = (
|
||||
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
|
||||
return hidden_states, gate_msa, tr_gate_msa
|
||||
|
||||
|
||||
class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
pooled_projection_dim: int,
|
||||
guidance_embeds: bool,
|
||||
image_condition_type: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.image_condition_type = image_condition_type
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
self.guidance_embedder = None
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(
|
||||
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
token_replace_emb = None
|
||||
if self.image_condition_type == "token_replace":
|
||||
token_replace_timestep = torch.zeros_like(timestep)
|
||||
token_replace_proj = self.time_proj(token_replace_timestep)
|
||||
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
|
||||
token_replace_emb = token_replace_emb + pooled_projections
|
||||
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
|
||||
conditioning = conditioning + guidance_emb
|
||||
|
||||
return conditioning, token_replace_emb
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -390,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
@@ -468,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
@@ -503,6 +644,181 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
|
||||
proj_output = self.proj_out(hidden_states)
|
||||
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
|
||||
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
token_replace_emb: torch.Tensor = None,
|
||||
num_tokens: int = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
tr_gate_msa,
|
||||
tr_shift_mlp,
|
||||
tr_scale_mlp,
|
||||
tr_gate_mlp,
|
||||
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
|
||||
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
|
||||
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
|
||||
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
@@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
image_condition_type (`str`, *optional*, defaults to `None`):
|
||||
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
|
||||
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
|
||||
tokens in the latent stream and apply conditioning.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -570,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_image_condition_types = ["latent_concat", "token_replace"]
|
||||
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
|
||||
raise ValueError(
|
||||
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
@@ -582,33 +909,52 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
else:
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
||||
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
||||
)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
if image_condition_type == "token_replace":
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -707,15 +1053,13 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
if self.config.guidance_embeds:
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
else:
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
@@ -746,6 +1090,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
@@ -756,17 +1102,31 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
token_replace_emb,
|
||||
first_frame_num_tokens,
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
|
||||
@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
>>> from diffusers.utils import load_image, export_to_video
|
||||
|
||||
>>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
|
||||
... )
|
||||
|
||||
>>> output = pipe(image=image, prompt=prompt).frames[0]
|
||||
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V
|
||||
>>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
|
||||
|
||||
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
|
||||
>>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
|
||||
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
image_embed_interleave: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
||||
image,
|
||||
@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_sequence_length=max_sequence_length,
|
||||
image_embed_interleave=image_embed_interleave,
|
||||
)
|
||||
|
||||
if pooled_prompt_embeds is None:
|
||||
@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
prompt_template=None,
|
||||
true_cfg_scale=1.0,
|
||||
guidance_scale=1.0,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
||||
)
|
||||
|
||||
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
|
||||
logger.warning(
|
||||
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
|
||||
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
|
||||
"as it may lead to higher memory usage, slower inference and potentially worse results."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
image_condition_type: str = "latent_concat",
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
image = image.unsqueeze(2) # [B, C, 1, H, W]
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
|
||||
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
|
||||
|
||||
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
|
||||
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
|
||||
@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
t = torch.tensor([0.999]).to(device=device)
|
||||
latents = latents * t + image_latents * (1 - t)
|
||||
|
||||
if image_condition_type == "token_replace":
|
||||
image_latents = image_latents[:, :, :1]
|
||||
|
||||
return latents, image_latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
||||
max_sequence_length: int = 256,
|
||||
image_embed_interleave: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_template,
|
||||
true_cfg_scale,
|
||||
guidance_scale,
|
||||
)
|
||||
|
||||
image_condition_type = self.transformer.config.image_condition_type
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
image_embed_interleave = (
|
||||
image_embed_interleave
|
||||
if image_embed_interleave is not None
|
||||
else (
|
||||
2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
|
||||
)
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
# 3. Prepare latent variables
|
||||
vae_dtype = self.vae.dtype
|
||||
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
|
||||
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
|
||||
|
||||
if image_condition_type == "latent_concat":
|
||||
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
|
||||
elif image_condition_type == "token_replace":
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
|
||||
latents, image_latents = self.prepare_latents(
|
||||
image_tensor,
|
||||
batch_size * num_videos_per_prompt,
|
||||
@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image_condition_type,
|
||||
)
|
||||
image_latents[:, :, 1:] = 0
|
||||
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
|
||||
mask[:, :, 1:] = 0
|
||||
if image_condition_type == "latent_concat":
|
||||
image_latents[:, :, 1:] = 0
|
||||
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
|
||||
mask[:, :, 1:] = 0
|
||||
|
||||
# 4. Encode input prompt
|
||||
transformer_dtype = self.transformer.dtype
|
||||
@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
image_embed_interleave=image_embed_interleave,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
||||
@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
|
||||
# 6. Prepare guidance condition
|
||||
guidance = None
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = (
|
||||
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if image_condition_type == "latent_concat":
|
||||
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
|
||||
elif image_condition_type == "token_replace":
|
||||
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_attention_mask=negative_prompt_attention_mask,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
if image_condition_type == "latent_concat":
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
elif image_condition_type == "token_replace":
|
||||
latents = latents = self.scheduler.step(
|
||||
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
|
||||
)[0]
|
||||
latents = torch.cat([image_latents, latents], dim=2)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = video[:, :, 4:, :, :]
|
||||
if image_condition_type == "latent_concat":
|
||||
video = video[:, :, 4:, :, :]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents[:, :, 1:, :, :]
|
||||
if image_condition_type == "latent_concat":
|
||||
video = latents[:, :, 1:, :, :]
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
@@ -80,6 +80,7 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
@@ -144,6 +145,7 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
@@ -209,6 +211,75 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "latent_concat",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 2
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (8, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 2,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 10,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"num_refiner_layers": 1,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "token_replace",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@@ -83,6 +83,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
text_embed_dim=16,
|
||||
pooled_projection_dim=8,
|
||||
rope_axes_dim=(2, 4, 4),
|
||||
image_condition_type="latent_concat",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user