When returning latents, return unpacked and denormalized latents for T2V and I2V
@@ -108,19 +108,6 @@ def main(args):
     if args.use_video_latents:
         # Manually convert the audio latents to a waveform
         audio = audio.to(pipeline.audio_vae.dtype)
-        audio = pipeline._denormalize_audio_latents(
-            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
-        )
-
-        sampling_rate = pipeline.audio_sampling_rate
-        hop_length = pipeline.audio_hop_length
-        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
-        duration_s = args.num_frames / args.frame_rate
-        latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
-        audio_latent_frames = int(duration_s * latents_per_second)
-        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
-        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)
-
         audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
         audio = pipeline.vocoder(audio)
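With this change the pipeline hands back audio latents that are already unpacked and denormalized, so the caller only casts and decodes. A minimal sketch of the resulting call site, assuming a loaded LTX2 pipeline; `latents_to_waveform` is a hypothetical helper name, not part of the codebase:

    import torch

    def latents_to_waveform(pipeline, audio_latents: torch.Tensor) -> torch.Tensor:
        # Latents returned with output_type="latent" are now decode-ready:
        # no manual _denormalize_audio_latents / _unpack_audio_latents calls.
        audio_latents = audio_latents.to(pipeline.audio_vae.dtype)
        mel = pipeline.audio_vae.decode(audio_latents, return_dict=False)[0]
        return pipeline.vocoder(mel)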
@@ -1070,21 +1070,27 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
 
             if not self.vae.config.timestep_conditioning:
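For context, denormalization maps the normalized latents the transformer works in back to the raw VAE latent distribution. A sketch of the affine transform implied by the call above, assuming per-channel mean/std buffers; it mirrors the helper's signature but is not a verbatim copy of the pipeline code:

    import torch

    def denormalize_latents_sketch(
        latents: torch.Tensor,       # (B, C, F, H, W) video latents
        latents_mean: torch.Tensor,  # per-channel mean, shape (C,)
        latents_std: torch.Tensor,   # per-channel std, shape (C,)
        scaling_factor: float = 1.0,
    ) -> torch.Tensor:
        # Broadcast per-channel statistics over batch, frames, height, width.
        mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        # Inverse of `normalized = (latents - mean) * scaling_factor / std`.
        return latents * std / scaling_factor + mean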
@@ -1109,10 +1115,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)
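The caller-facing effect: `output_type="latent"` now returns tensors that downstream code can decode directly. A hedged usage sketch; the checkpoint id is a placeholder and the output attribute names are assumptions, not confirmed by this diff:

    import torch
    from diffusers import LTX2Pipeline

    pipe = LTX2Pipeline.from_pretrained(
        "<ltx2-checkpoint>",  # placeholder id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    out = pipe(prompt="A calm ocean at dusk", output_type="latent")
    video_latents, audio_latents = out.frames, out.audio  # attribute names assumed

    # Both tensors arrive unpacked and denormalized, so they can go straight
    # to pipe.vae.decode(...) / pipe.audio_vae.decode(...) or to another pipeline.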
@@ -1166,21 +1166,27 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
 
             if not self.vae.config.timestep_conditioning:
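`_unpack_latents` is the inverse of the patchification done before the transformer: it folds a `(batch, seq_len, hidden)` token sequence back into a `(batch, channels, frames, height, width)` latent. A shape-only sketch, assuming packing is a plain rearrangement by the patch sizes passed above; illustrative, not the pipeline's exact helper:

    import torch

    def unpack_latents_sketch(
        latents: torch.Tensor,  # (B, seq_len, hidden) patch sequence
        num_frames: int, height: int, width: int,  # latent-space sizes
        spatial_patch: int = 1, temporal_patch: int = 1,
    ) -> torch.Tensor:
        batch = latents.size(0)
        # seq_len == (num_frames/tp) * (height/sp) * (width/sp); each token
        # holds C * tp * sp * sp values.
        latents = latents.reshape(
            batch,
            num_frames // temporal_patch,
            height // spatial_patch,
            width // spatial_patch,
            -1, temporal_patch, spatial_patch, spatial_patch,
        )
        # Interleave patch cells back into their frame/row/column positions.
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7)
        return latents.flatten(6, 7).flatten(4, 5).flatten(2, 3)  # (B, C, F, H, W)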
@@ -1205,10 +1211,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)
@@ -154,7 +154,8 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
 
         init_latents = torch.cat(init_latents, dim=0).to(dtype)
-        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+        # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
+        # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
         return init_latents
 
     def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
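The NOTE above is the crux of the upsampler changes: the latent upsampler operates on raw (unnormalized) VAE latents, and normalize/denormalize are exact inverses, so normalizing in `prepare_latents` would feed it the wrong distribution. A self-contained sketch of the pair, with per-channel statistics assumed as in the earlier denormalization sketch:

    import torch

    def normalize_sketch(x, mean, std, scaling_factor=1.0):
        mean = mean.view(1, -1, 1, 1, 1).to(x.device, x.dtype)
        std = std.view(1, -1, 1, 1, 1).to(x.device, x.dtype)
        return (x - mean) * scaling_factor / std

    def denormalize_sketch(x, mean, std, scaling_factor=1.0):
        mean = mean.view(1, -1, 1, 1, 1).to(x.device, x.dtype)
        std = std.view(1, -1, 1, 1, 1).to(x.device, x.dtype)
        return x * std / scaling_factor + mean

    # Round trip: the two cancel out, so "normalize here" and "upsample raw
    # latents" are mutually exclusive conventions.
    x = torch.randn(2, 128, 3, 8, 8)
    mean, std = torch.randn(128), torch.rand(128) + 0.5
    y = denormalize_sketch(normalize_sketch(x, mean, std), mean, std)
    assert torch.allclose(x, y, atol=1e-5)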
@@ -275,6 +276,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         spatial_patch_size: int = 1,
         temporal_patch_size: int = 1,
         latents: Optional[torch.Tensor] = None,
+        latents_normalized: bool = False,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
         adain_factor: float = 0.0,
@@ -305,6 +307,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
                 Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
                 patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
                 latent_channels, latent_frames, latent_height, latent_width)`.
+            latents_normalized (`bool`, *optional*, defaults to `False`):
+                If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
+                `True`, the `latents` will be denormalized before being supplied to the latent upsampler.
             decode_timestep (`float`, defaults to `0.0`):
                 The timestep at which generated video is decoded.
             decode_noise_scale (`float`, defaults to `None`):
@@ -362,6 +367,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             video = video.to(device=device, dtype=torch.float32)
 
+        latents_supplied = latents is not None
         latents = self.prepare_latents(
             video=video,
             batch_size=batch_size,
@@ -376,9 +382,10 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             latents=latents,
         )
 
-        latents = self._denormalize_latents(
-            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-        )
+        if latents_supplied and latents_normalized:
+            latents = self._denormalize_latents(
+                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+            )
         latents = latents.to(self.latent_upsampler.dtype)
         latents_upsampled = self.latent_upsampler(latents)
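Taken together, the upsampler now accepts pre-generated latents in either convention and only denormalizes when told they are normalized. A hedged end-to-end sketch, assuming both pipelines are exported from diffusers; checkpoint ids are placeholders and the base pipeline's output attribute name is an assumption:

    import torch
    from diffusers import LTX2Pipeline, LTX2LatentUpsamplePipeline

    base = LTX2Pipeline.from_pretrained("<ltx2-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
    upsampler = LTX2LatentUpsamplePipeline.from_pretrained(
        "<ltx2-upsampler-checkpoint>", torch_dtype=torch.bfloat16
    ).to("cuda")

    # After this commit the base pipeline's latent output is already
    # denormalized, so the new default latents_normalized=False is correct.
    latents = base(prompt="A calm ocean at dusk", output_type="latent").frames  # attribute name assumed

    upsampled = upsampler(latents=latents, latents_normalized=False)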