diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py index d76458403c..3dcf02ce61 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py @@ -979,6 +979,11 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi last_latent_tile_num_frames = last_latent_chunk.shape[2] latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2) total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames + + conditioning_mask = torch.zeros( + (batch_size, total_latent_num_frames), dtype=torch.float32, device=device, + ) + conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 else: total_latent_num_frames = latent_tile_num_frames @@ -998,6 +1003,21 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi device=device, ) + if start_index > 0: + conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0]) + + video_ids = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device + ) + video_ids = video_ids.float() + video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate) + if self.do_classifier_free_guidance: + video_ids = torch.cat([video_ids, video_ids], dim=0) + # Set timesteps inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) sigmas = self.scheduler.sigmas @@ -1005,18 +1025,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi self._num_timesteps = len(inner_timesteps) if start_index == 0: - video_ids = self._scale_video_ids( - video_ids, - scale_factor=self.vae_spatial_compression_ratio, - scale_factor_t=self.vae_temporal_compression_ratio, - frame_index=0, - device=device - ) - video_ids = video_ids.float() - video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate) - if self.do_classifier_free_guidance: - video_ids = torch.cat([video_ids, video_ids], dim=0) - with self.progress_bar(total=inner_num_inference_steps) as progress_bar: for i, t in enumerate(inner_timesteps): if self.interrupt: @@ -1079,24 +1087,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi ) first_tile_out_latents = tile_out_latents.clone() else: - conditioning_mask = torch.zeros( - (batch_size, total_latent_num_frames), dtype=torch.float32, device=device, - ) - conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 - conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0]) - - video_ids = self._scale_video_ids( - video_ids, - scale_factor=self.vae_spatial_compression_ratio, - scale_factor_t=self.vae_temporal_compression_ratio, - frame_index=0, - device=device - ) - video_ids = video_ids.float() - video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate) - if self.do_classifier_free_guidance: - video_ids = torch.cat([video_ids, video_ids], dim=0) - with self.progress_bar(total=inner_num_inference_steps) as progress_bar: for i, t in enumerate(inner_timesteps): if self.interrupt: