
Infer latent dims if latents/audio_latents is supplied

Author: Daniel Gu
Date: 2026-01-14 03:09:53 +01:00
parent ce5a51430b
commit f4d47b9cec
2 changed files with 61 additions and 43 deletions
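Before this change, both pipelines always derived the latent sizes from `height`, `width`, and `num_frames`; after it, when unpacked `latents` ([B, C, F, H, W]) or `audio_latents` ([B, C, L, M]) are supplied, the dimensions are read from the tensors themselves. A minimal sketch of the calling pattern this enables, assuming `LTX2Pipeline` is exported from the top-level `diffusers` namespace and that `audio_latents` is accepted as a `__call__` kwarg; the checkpoint id and tensor shapes are placeholders, not values taken from the diff:

import torch
from diffusers import LTX2Pipeline  # assumed top-level export

# Placeholder checkpoint id; substitute a real LTX-2 checkpoint.
pipe = LTX2Pipeline.from_pretrained("<ltx2-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")

# Unpacked video latents [B, C, F, H, W]: F, H, W are now inferred from the tensor.
video_latents = torch.randn(1, 128, 16, 16, 24, dtype=torch.bfloat16)
# Unpacked audio latents [B, C, L, M]: the latent length L is likewise inferred.
audio_latents = torch.randn(1, 8, 121, 16, dtype=torch.float32)

video = pipe(
    prompt="a dog chasing a ball on a beach",
    latents=video_latents,
    audio_latents=audio_latents,
)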


@@ -682,32 +682,23 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
return latents.to(device=device, dtype=dtype), latent_length
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -717,7 +708,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -935,6 +926,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
@@ -950,20 +949,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents,
)
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,


@@ -742,32 +742,23 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
self,
batch_size: int = 1,
num_channels_latents: int = 8,
audio_latent_length: int = 1, # 1 is just a dummy value
num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
duration_s = num_frames / frame_rate
latents_per_second = (
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
)
latent_length = round(duration_s * latents_per_second)
if latents is not None:
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
return latents.to(device=device, dtype=dtype), latent_length
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -777,7 +768,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
return latents, latent_length
return latents
@property
def guidance_scale(self):
@@ -995,6 +986,19 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 4. Prepare latent variables
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
if latents is not None:
if latents.ndim == 5:
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
else:
logger.warning(
f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -1015,20 +1019,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
duration_s = num_frames / frame_rate
audio_latents_per_second = (
self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
)
audio_num_frames = round(duration_s * audio_latents_per_second)
if audio_latents is not None:
if audio_latents.ndim == 4:
_, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M]
else:
logger.warning(
f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
f" cannot be inferred. Make sure the supplied `num_frames` is correct."
)
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
audio_latents, audio_num_frames = self.prepare_audio_latents(
audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
sampling_rate=self.audio_sampling_rate,
hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
@@ -1036,11 +1050,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
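
Taken together, the video-side inference logic added in both pipelines amounts to the following standalone sketch; the compression ratios are example values rather than ones read from an actual VAE config:

import torch

def infer_video_latent_dims(
    latents, height, width, num_frames,
    vae_temporal_compression_ratio=8,   # example value
    vae_spatial_compression_ratio=32,   # example value
):
    # Fallback: derive the latent dims from the requested output size, as before this commit.
    latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
    latent_height = height // vae_spatial_compression_ratio
    latent_width = width // vae_spatial_compression_ratio
    if latents is not None and latents.ndim == 5:
        # Unpacked latents [B, C, F, H, W]: trust the tensor's own dimensions.
        _, _, latent_num_frames, latent_height, latent_width = latents.shape
    return latent_num_frames, latent_height, latent_width

# A [1, 128, 16, 16, 24] tensor yields (16, 16, 24) regardless of height/width/num_frames.
print(infer_video_latent_dims(torch.randn(1, 128, 16, 16, 24), height=704, width=1216, num_frames=121))

Packed (3D) latents carry no explicit temporal or spatial structure, which is why the diff only logs a warning in that case and keeps the values derived from `height`, `width`, and `num_frames`.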