diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 588e05737e..54f1061da5 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -682,32 +682,23 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
+        audio_latent_length: int = 1,  # 1 is just a dummy value
         num_mel_bins: int = 64,
-        num_frames: int = 121,
-        frame_rate: float = 25.0,
-        sampling_rate: int = 16000,
-        hop_length: int = 160,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        duration_s = num_frames / frame_rate
-        latents_per_second = (
-            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
-        )
-        latent_length = round(duration_s * latents_per_second)
-
         if latents is not None:
             if latents.ndim == 4:
                 # latents are of shape [B, C, L, M], need to be packed
                 latents = self._pack_audio_latents(latents)
-            return latents.to(device=device, dtype=dtype), latent_length
+            return latents.to(device=device, dtype=dtype)
 
         # TODO: confirm whether this logic is correct
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -717,7 +708,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_audio_latents(latents)
-        return latents, latent_length
+        return latents
 
     @property
     def guidance_scale(self):
@@ -935,6 +926,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
         latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
+        if latents is not None:
+            if latents.ndim == 5:
+                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
+            else:
+                logger.warning(
+                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
+                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
+                )
         video_sequence_length = latent_num_frames * latent_height * latent_width
 
         num_channels_latents = self.transformer.config.in_channels
@@ -950,20 +949,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
             latents,
         )
 
+        duration_s = num_frames / frame_rate
+        audio_latents_per_second = (
+            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
+        )
+        audio_num_frames = round(duration_s * audio_latents_per_second)
+        if audio_latents is not None:
+            if audio_latents.ndim == 4:
+                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
+            else:
+                logger.warning(
+                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
+                    f" cannot be inferred. Make sure the supplied `num_frames` is correct."
+                )
+
         num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-
         num_channels_latents_audio = (
             self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
         )
-        audio_latents, audio_num_frames = self.prepare_audio_latents(
+        audio_latents = self.prepare_audio_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents_audio,
+            audio_latent_length=audio_num_frames,
             num_mel_bins=num_mel_bins,
-            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
-            frame_rate=frame_rate,
-            sampling_rate=self.audio_sampling_rate,
-            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
index 92206cee4e..460ff8eec7 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
@@ -742,32 +742,23 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
+        audio_latent_length: int = 1,  # 1 is just a dummy value
         num_mel_bins: int = 64,
-        num_frames: int = 121,
-        frame_rate: float = 25.0,
-        sampling_rate: int = 16000,
-        hop_length: int = 160,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        duration_s = num_frames / frame_rate
-        latents_per_second = (
-            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
-        )
-        latent_length = round(duration_s * latents_per_second)
-
         if latents is not None:
             if latents.ndim == 4:
                 # latents are of shape [B, C, L, M], need to be packed
                 latents = self._pack_audio_latents(latents)
-            return latents.to(device=device, dtype=dtype), latent_length
+            return latents.to(device=device, dtype=dtype)
 
         # TODO: confirm whether this logic is correct
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -777,7 +768,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_audio_latents(latents)
-        return latents, latent_length
+        return latents
 
     @property
     def guidance_scale(self):
@@ -995,6 +986,19 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
         )
 
         # 4. Prepare latent variables
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+        latent_height = height // self.vae_spatial_compression_ratio
+        latent_width = width // self.vae_spatial_compression_ratio
+        if latents is not None:
+            if latents.ndim == 5:
+                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
+            else:
+                logger.warning(
+                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
+                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
+                )
+        video_sequence_length = latent_num_frames * latent_height * latent_width
+
         if latents is None:
             image = self.video_processor.preprocess(image, height=height, width=width)
             image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -1015,20 +1019,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
         if self.do_classifier_free_guidance:
             conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
 
+        duration_s = num_frames / frame_rate
+        audio_latents_per_second = (
+            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
+        )
+        audio_num_frames = round(duration_s * audio_latents_per_second)
+        if audio_latents is not None:
+            if audio_latents.ndim == 4:
+                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
+            else:
+                logger.warning(
+                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
+                    f" cannot be inferred. Make sure the supplied `num_frames` is correct."
+                )
+
         num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-
         num_channels_latents_audio = (
             self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
         )
-        audio_latents, audio_num_frames = self.prepare_audio_latents(
+        audio_latents = self.prepare_audio_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents_audio,
+            audio_latent_length=audio_num_frames,
             num_mel_bins=num_mel_bins,
-            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
-            frame_rate=frame_rate,
-            sampling_rate=self.audio_sampling_rate,
-            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
@@ -1036,11 +1050,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
         )
 
         # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        video_sequence_length = latent_num_frames * latent_height * latent_width
-
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         mu = calculate_shift(
             video_sequence_length,
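For reference, below is a minimal standalone sketch of the audio-latent-length arithmetic that both pipelines now perform in `__call__` before calling `prepare_audio_latents`. The numeric values reuse the old defaults removed from the helper's signature (`num_frames=121`, `frame_rate=25.0`, `sampling_rate=16000`, `hop_length=160`); the temporal compression ratio of 4, the variable `user_audio_latents`, and its example shape are assumptions for illustration only, not necessarily what the LTX2 audio VAE config uses.

```python
import torch

# Old prepare_audio_latents() defaults removed in this diff; the compression
# ratio of 4 is an assumed value for illustration.
num_frames, frame_rate = 121, 25.0
sampling_rate, hop_length = 16000, 160
audio_vae_temporal_compression_ratio = 4

# Same arithmetic the diff moves into __call__ before prepare_audio_latents().
duration_s = num_frames / frame_rate  # 4.84 s of video
audio_latents_per_second = sampling_rate / hop_length / float(audio_vae_temporal_compression_ratio)
audio_num_frames = round(duration_s * audio_latents_per_second)  # 121 with these numbers

# If unpacked audio latents ([B, C, L, M]) are supplied, the length is read
# from the tensor instead, mirroring the new inference branch in __call__.
user_audio_latents = torch.randn(1, 8, 97, 16)  # hypothetical user-supplied latents
if user_audio_latents.ndim == 4:
    _, _, audio_num_frames, _ = user_audio_latents.shape  # -> 97

print(audio_num_frames)
```

Moving this computation out of `prepare_audio_latents` keeps that helper a pure shape/noise utility that only needs `audio_latent_length`, and it lets explicitly supplied `latents`/`audio_latents` tensors override the dimensions otherwise derived from `height`, `width`, `num_frames`, and `frame_rate`, which is what the new shape-inference branches above do for the unpacked case.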