From 0e97669edb384dc620ac3a4e6c7f0e8c0f843906 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 14 Aug 2025 03:36:10 +0200
Subject: [PATCH] try generating in reverse like... like what seems to be done
 in original codebase

---
 .../ltx/pipeline_ltx_condition_infinite.py | 31 +++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py
index c4bfce9a5a..d7bd24c5d1 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py
@@ -974,9 +974,11 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                     latent_tile_num_frames = latent_chunk.shape[2]
 
                     if start_index > 0:
-                        last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
+                        # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
+                        last_latent_chunk = self._select_latents(tile_out_latents, 0, temporal_overlap - 1)
+                        last_latent_chunk = torch.flip(last_latent_chunk, dims=[2])
                         last_latent_tile_num_frames = last_latent_chunk.shape[2]
-                        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
+                        latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
                         total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
                         last_latent_chunk = self._pack_latents(
                             last_latent_chunk,
@@ -993,7 +995,9 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                             device=device,
                         )
                         # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
-                        conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
+                        # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
+                        conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
+                        # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0
                     else:
                         total_latent_num_frames = latent_tile_num_frames
 
@@ -1051,14 +1055,14 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                             torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
                         )
                         latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                        # Create timestep tensor that has prod(latent_model_input.shape) elements
+
                         if start_index == 0:
                             timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
                         else:
                             timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
-                            timestep[:, :last_latent_chunk_num_tokens] = 0.0
-
+                            timestep[:, -last_latent_chunk_num_tokens:] = 0.0
                         timestep = timestep.float()
+                        # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
 
                         # if start_index > 0:
                         #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
@@ -1094,7 +1098,8 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                             latent_chunk = denoised_latent_chunk
                         else:
                             latent_chunk = torch.cat(
-                                [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
+                                [denoised_latent_chunk[:, :-last_latent_chunk_num_tokens], last_latent_chunk],
+                                dim=1,
                             )
                         # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
                         # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
@@ -1129,7 +1134,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                     if start_index == 0:
                         first_tile_out_latents = latent_chunk.clone()
                     else:
-                        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
+                        latent_chunk = latent_chunk[:, :, 1:-last_latent_tile_num_frames, :, :]
                         latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
                             latent_chunk, first_tile_out_latents, adain_factor
                         )
@@ -1140,10 +1145,10 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                         # Combine samples
                         t_minus_one = temporal_overlap - 1
                         parts = [
-                            tile_out_latents[:, :, :-t_minus_one],
-                            alpha * tile_out_latents[:, :, -t_minus_one:]
-                            + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
-                            latent_chunk[:, :, t_minus_one:],
+                            latent_chunk[:, :, :-t_minus_one],
+                            (1 - alpha) * latent_chunk[:, :, -t_minus_one:]
+                            + alpha * tile_out_latents[:, :, :t_minus_one],
+                            tile_out_latents[:, :, t_minus_one:],
                         ]
                         latent_chunk = torch.cat(parts, dim=2)
 
@@ -1152,7 +1157,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                 tile_weights = self._create_spatial_weights(
                     tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
                 )
-                final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
+                final_latents[:, :, :, v_start:v_end, h_start:h_end] += tile_out_latents * tile_weights
                 weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
 
         eps = 1e-8
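
Note on the change, with an illustrative sketch: the hunks above reverse the temporal conditioning direction. Instead of prepending the last `temporal_overlap` latent frames of the previous tile, the chunk now conditions on the previous tile's first `temporal_overlap` frames, flipped along the frame axis and appended after the new chunk, with those trailing positions marked as known via the conditioning mask and a zeroed timestep. The snippet below is a minimal standalone sketch of that overlap construction, not the pipeline code itself; the tensor shapes, the `temporal_overlap` value, and the per-frame (rather than per-packed-token) mask and timestep are illustrative assumptions.

import torch

# Illustrative shapes only: (batch, channels, frames, height, width) latents.
prev_tile_latents = torch.randn(1, 128, 16, 32, 32)  # previously generated tile
new_chunk_latents = torch.randn(1, 128, 16, 32, 32)  # noise for the next chunk
temporal_overlap = 4
temporal_overlap_cond_strength = 0.5

# Take the FIRST `temporal_overlap` frames of the previous tile and flip them
# along the frame axis (dim=2), mirroring the reverse generation order.
overlap = torch.flip(prev_tile_latents[:, :, :temporal_overlap], dims=[2])

# Append the flipped overlap AFTER the new chunk instead of prepending it.
latent_chunk = torch.cat([new_chunk_latents, overlap], dim=2)

# Conditioning mask over frames: the trailing overlap frames are treated as
# known context (the real pipeline applies this over packed latent tokens).
num_frames = latent_chunk.shape[2]
conditioning_mask = torch.zeros(1, num_frames)
conditioning_mask[:, -temporal_overlap:] = temporal_overlap_cond_strength

# During denoising, the timestep for the trailing overlap positions is forced
# to 0 so the model sees them as already clean.
t = torch.tensor(800.0)
timestep = t.view(1, 1).expand(1, num_frames).clone()
timestep[:, -temporal_overlap:] = 0.0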