diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py index 0772ed0fa0..d76458403c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py @@ -973,6 +973,30 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi ): latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1)) latent_tile_num_frames = latent_chunk.shape[2] + + if start_index > 0: + last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) + last_latent_tile_num_frames = last_latent_chunk.shape[2] + latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2) + total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames + else: + total_latent_num_frames = latent_tile_num_frames + + latent_chunk = self._pack_latents( + latent_chunk, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + video_ids = self._prepare_video_ids( + batch_size, + total_latent_num_frames, + latent_tile_height, + latent_tile_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, + ) # Set timesteps inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -981,17 +1005,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi self._num_timesteps = len(inner_timesteps) if start_index == 0: - latent_chunk = self._pack_latents(latent_chunk, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) - - video_ids = self._prepare_video_ids( - batch_size, - latent_tile_num_frames, - latent_tile_height, - latent_tile_width, - patch_size_t=self.transformer_temporal_patch_size, - patch_size=self.transformer_spatial_patch_size, - device=device, - ) video_ids = self._scale_video_ids( video_ids, scale_factor=self.vae_spatial_compression_ratio, @@ -1066,26 +1079,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi ) first_tile_out_latents = tile_out_latents.clone() else: - last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) - last_latent_tile_num_frames = last_latent_chunk.shape[2] - latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2) - total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames - latent_chunk = self._pack_latents( - latent_chunk, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - - video_ids = self._prepare_video_ids( - batch_size, - total_latent_num_frames, - latent_tile_height, - latent_tile_width, - patch_size_t=self.transformer_temporal_patch_size, - patch_size=self.transformer_spatial_patch_size, - device=device, - ) - conditioning_mask = torch.zeros( (batch_size, total_latent_num_frames), dtype=torch.float32, device=device, )