diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index bf3f3d13c5..bc2559ebbc 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -761,11 +761,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): """ # 1. Generate coordinates in the frame (time) dimension. - audio_duration_s = num_frames / fps - latent_frames = int(audio_duration_s * self.audio_latents_per_second) # Always compute rope in fp32 grid_f = torch.arange( - start=shift, end=latent_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device ) # 2. Calculate start timstamps in seconds with respect to the original spectrogram grid diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index e95b8d5c0b..cbfb5b5c4a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -689,7 +689,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latents_per_second = ( float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) ) - latent_length = int(duration_s * latents_per_second) + latent_length = round(duration_s * latents_per_second) if latents is not None: return latents.to(device=device, dtype=dtype), latent_length diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 5a4b272809..652955fee1 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -749,7 +749,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents_per_second = ( float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) ) - latent_length = int(duration_s * latents_per_second) + latent_length = round(duration_s * latents_per_second) if latents is not None: return latents.to(device=device, dtype=dtype), latent_length