mirror of
https://github.com/huggingface/diffusers.git
Add support for I2V (#8)
* start i2v.
* up
* up
* up
* up
* up
* remove uniform strategy code.
* remove unneeded code.
@@ -538,6 +538,7 @@ else:
 "LTXLatentUpsamplePipeline",
 "LTXPipeline",
 "LTX2Pipeline",
+"LTX2ImageToVideoPipeline",
 "LucyEditPipeline",
 "Lumina2Pipeline",
 "Lumina2Text2ImgPipeline",
@@ -1245,6 +1246,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 LTXLatentUpsamplePipeline,
 LTXPipeline,
 LTX2Pipeline,
+LTX2ImageToVideoPipeline,
 LucyEditPipeline,
 Lumina2Pipeline,
 Lumina2Text2ImgPipeline,
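With the two registrations above, the new pipeline is exported from the package root alongside LTX2Pipeline, so it should be importable directly (minimal sketch):

from diffusers import LTX2ImageToVideoPipeline, LTX2Pipeline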
@@ -1051,6 +1051,7 @@ class LTX2VideoTransformer3DModel(
 encoder_hidden_states: torch.Tensor,
 audio_encoder_hidden_states: torch.Tensor,
 timestep: torch.LongTensor,
+audio_timestep: Optional[torch.LongTensor] = None,
 encoder_attention_mask: Optional[torch.Tensor] = None,
 audio_encoder_attention_mask: Optional[torch.Tensor] = None,
 num_frames: Optional[int] = None,
@@ -1073,8 +1074,7 @@ class LTX2VideoTransformer3DModel(
Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels).
encoder_hidden_states (`torch.Tensor`):
Input text embeddings of shape TODO.
timesteps (`torch.Tensor`):
Timestep information of shape (batch_size, num_train_timesteps).
TODO for the rest.

Returns:
`AudioVisualModelOutput` or `tuple`:
@@ -1097,6 +1097,9 @@ class LTX2VideoTransformer3DModel(
 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
 )

+# Determine timestep for audio.
+audio_timestep = audio_timestep if audio_timestep is not None else timestep
+
 # convert encoder_attention_mask to a bias the same way we do for attention_mask
 if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
 encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
@@ -1143,7 +1146,7 @@ class LTX2VideoTransformer3DModel(
 embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

 temb_audio, audio_embedded_timestep = self.audio_time_embed(
-timestep.flatten(),
+audio_timestep.flatten(),
 batch_size=batch_size,
 hidden_dtype=audio_hidden_states.dtype,
 )
@@ -1165,12 +1168,12 @@ class LTX2VideoTransformer3DModel(
 video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

 audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
-timestep.flatten(),
+audio_timestep.flatten(),
 batch_size=batch_size,
 hidden_dtype=audio_hidden_states.dtype,
 )
 audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
-timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
 batch_size=batch_size,
 hidden_dtype=audio_hidden_states.dtype,
 )
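Taken together, the transformer hunks add an optional audio_timestep to forward(): the audio time embedding and the audio-side cross-attention scale-shift and gate projections now consume audio_timestep instead of the video timestep, and it falls back to timestep when not provided. A minimal, hypothetical call sketch; the keyword form of `hidden_states`/`audio_hidden_states` and the local variable names are assumptions based on the docstring context above, not part of this diff:

# Hypothetical sketch: drive the audio branch with its own timestep.
out = transformer(
    hidden_states=video_latents,
    audio_hidden_states=audio_latents,
    encoder_hidden_states=prompt_embeds,
    audio_encoder_hidden_states=audio_prompt_embeds,
    timestep=timestep,
    audio_timestep=audio_timestep,  # optional; defaults to `timestep` when None
)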
@@ -288,7 +288,7 @@ else:
 "LTXConditionPipeline",
 "LTXLatentUpsamplePipeline",
 ]
-_import_structure["ltx2"] = ["LTX2Pipeline"]
+_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]
 _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
 _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
 _import_structure["lucy"] = ["LucyEditPipeline"]
@@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 LEditsPPPipelineStableDiffusionXL,
 )
 from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
-from .ltx2 import LTX2Pipeline
+from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline
 from .lucy import LucyEditPipeline
 from .lumina import LuminaPipeline, LuminaText2ImgPipeline
 from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
 _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
+_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
 _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
 _import_structure["vocoder"] = ["LTX2Vocoder"]
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 from ...utils.dummy_torch_and_transformers_objects import *
 else:
 from .pipeline_ltx2 import LTX2Pipeline
+from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
 from .text_encoder import LTX2AudioVisualTextEncoder
 from .vocoder import LTX2Vocoder
@@ -496,16 +496,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
 latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
 return latents
-
-@staticmethod
-def _normalize_latents(
-latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
-) -> torch.Tensor:
-# Normalize latents across the channel dimension [B, C, F, H, W]
-latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-latents = (latents - latents_mean) * scaling_factor / latents_std
-return latents

 @staticmethod
 def _denormalize_latents(
 latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
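The removed _normalize_latents maps latents to (latents - mean) * scaling_factor / std; the retained _denormalize_latents, whose signature appears as context above, inverts that mapping. A sketch of the inverse, assuming it matches the existing LTX implementation rather than anything shown in this diff:

@staticmethod
def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    # Denormalize latents across the channel dimension [B, C, F, H, W]
    # (assumed inverse of the removed _normalize_latents)
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
    return latents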
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 1138 lines (new file)
File diff suppressed because it is too large.
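Since the 1138-line pipeline file itself is not shown, here is a hypothetical usage sketch for the new image-to-video pipeline; the checkpoint id and call arguments are assumptions modeled on the existing LTX image-to-video API, not taken from this diff:

import torch
from diffusers import LTX2ImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Hypothetical checkpoint id; replace with a real LTX2 repository.
pipe = LTX2ImageToVideoPipeline.from_pretrained("<ltx2-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("input.png")
# Argument names follow the LTX image-to-video convention; the LTX2 call signature may differ.
output = pipe(image=image, prompt="describe the desired motion", num_frames=121)
export_to_video(output.frames[0], "output.mp4", fps=24)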