mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

When returning latents, return unpacked and denormalized latents for T2V and I2V

Author: Daniel Gu
Date: 2026-01-07 09:04:34 +01:00
parent e6e7e7b26f
commit aa9b65d0fc
4 changed files with 45 additions and 47 deletions
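In practice, the change means latent outputs can be decoded without first calling the private denormalize/unpack helpers yourself. A rough usage sketch, assuming the pipeline output exposes the latents as `.frames` and `.audio` and using a placeholder checkpoint id (neither is confirmed by this diff):

import torch
from diffusers import LTX2Pipeline  # import path assumed

# Placeholder checkpoint id; substitute the actual LTX-2 weights.
pipe = LTX2Pipeline.from_pretrained("path/to/ltx2-checkpoint", torch_dtype=torch.bfloat16).to("cuda")

out = pipe(prompt="a red fox running through snow", output_type="latent")
video_latents, audio_latents = out.frames, out.audio  # attribute names assumed; already unpacked + denormalized

# Video latents can now go straight to the VAE decoder (timestep-conditioned VAEs
# may additionally want a decode timestep, as the pipeline code below shows).
video = pipe.vae.decode(video_latents.to(pipe.vae.dtype), return_dict=False)[0]

# Audio latents only need decode + vocoder, mirroring the simplified example script below.
mel = pipe.audio_vae.decode(audio_latents.to(pipe.audio_vae.dtype), return_dict=False)[0]
waveform = pipe.vocoder(mel)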


@@ -108,19 +108,6 @@ def main(args):
     if args.use_video_latents:
         # Manually convert the audio latents to a waveform
         audio = audio.to(pipeline.audio_vae.dtype)
-        audio = pipeline._denormalize_audio_latents(
-            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
-        )
-        sampling_rate = pipeline.audio_sampling_rate
-        hop_length = pipeline.audio_hop_length
-        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
-        duration_s = args.num_frames / args.frame_rate
-        latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
-        audio_latent_frames = int(duration_s * latents_per_second)
-        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
-        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)
         audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
         audio = pipeline.vocoder(audio)
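For reference, the block removed above converted the requested clip duration into a number of audio latent frames before unpacking. A small worked example of that arithmetic with placeholder values (the real sampling rate, hop length, and compression ratios come from the pipeline and audio VAE configs):

# Illustrative values only; the actual numbers come from pipeline.audio_sampling_rate,
# pipeline.audio_hop_length, and the audio VAE compression ratios.
sampling_rate = 44100           # audio samples per second (assumed)
hop_length = 512                # samples per mel-spectrogram frame (assumed)
audio_vae_temporal_scale = 4    # mel frames per audio latent frame (assumed)
num_frames, frame_rate = 121, 24

duration_s = num_frames / frame_rate                                        # ~5.04 s of video
latents_per_second = sampling_rate / hop_length / audio_vae_temporal_scale  # ~21.53
audio_latent_frames = int(duration_s * latents_per_second)                  # 108

After this commit the pipeline performs this bookkeeping internally, so the example script only needs the decode and vocoder calls kept above.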


@@ -1070,21 +1070,27 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
             if not self.vae.config.timestep_conditioning:
@@ -1109,10 +1115,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             video = self.video_processor.postprocess_video(video, output_type=output_type)
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)


@@ -1166,21 +1166,27 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
                 if XLA_AVAILABLE:
                     xm.mark_step()
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
             if not self.vae.config.timestep_conditioning:
@@ -1205,10 +1211,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
             video = self.video_processor.postprocess_video(video, output_type=output_type)
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)


@@ -154,7 +154,8 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
         init_latents = torch.cat(init_latents, dim=0).to(dtype)
-        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+        # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
+        # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
         return init_latents

     def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
@@ -275,6 +276,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         spatial_patch_size: int = 1,
         temporal_patch_size: int = 1,
         latents: Optional[torch.Tensor] = None,
+        latents_normalized: bool = False,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
         adain_factor: float = 0.0,
@@ -305,6 +307,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
                 Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
                 patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
                 latent_channels, latent_frames, latent_height, latent_width)`.
+            latents_normalized (`bool`, *optional*, defaults to `False`):
+                If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
+                `True`, the `latents` will be denormalized before being supplied to the latent upsampler.
             decode_timestep (`float`, defaults to `0.0`):
                 The timestep at which generated video is decoded.
             decode_noise_scale (`float`, defaults to `None`):
@@ -362,6 +367,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             video = video.to(device=device, dtype=torch.float32)

+        latents_supplied = latents is not None
         latents = self.prepare_latents(
             video=video,
             batch_size=batch_size,
@@ -376,9 +382,10 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             latents=latents,
         )
-        latents = self._denormalize_latents(
-            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-        )
+        if latents_supplied and latents_normalized:
+            latents = self._denormalize_latents(
+                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+            )

         latents = latents.to(self.latent_upsampler.dtype)
         latents_upsampled = self.latent_upsampler(latents)
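The new `latents_normalized` flag covers callers that hand the upsampler latents which were normalized with the VAE `latents_mean`/`latents_std`; since the T2V and I2V pipelines now return denormalized latents, the default `False` is the right setting for a direct handoff. A minimal sketch, assuming `LTX2LatentUpsamplePipeline` is importable from `diffusers`, a placeholder checkpoint path, and an output attribute named `.frames`:

import torch
from diffusers import LTX2Pipeline, LTX2LatentUpsamplePipeline  # import paths assumed

pipe = LTX2Pipeline.from_pretrained("path/to/ltx2-checkpoint", torch_dtype=torch.bfloat16).to("cuda")
upsampler = LTX2LatentUpsamplePipeline.from_pretrained(
    "path/to/ltx2-latent-upsampler", torch_dtype=torch.bfloat16  # placeholder path
).to("cuda")

# The base pipeline now returns unpacked, *denormalized* latents.
latents = pipe(prompt="a red fox running through snow", output_type="latent").frames  # attribute name assumed

# latents_normalized=False (the default) because these latents are not normalized;
# pass True only for latents that were normalized with the VAE latents_mean/latents_std.
upsampled = upsampler(latents=latents, latents_normalized=False)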