diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py index d7bd24c5d1..c4bfce9a5a 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py @@ -974,11 +974,9 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi latent_tile_num_frames = latent_chunk.shape[2] if start_index > 0: - # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) - last_latent_chunk = self._select_latents(tile_out_latents, 0, temporal_overlap - 1) - last_latent_chunk = torch.flip(last_latent_chunk, dims=[2]) + last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) last_latent_tile_num_frames = last_latent_chunk.shape[2] - latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2) + latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2) total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames last_latent_chunk = self._pack_latents( last_latent_chunk, @@ -995,9 +993,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi device=device, ) # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength - # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 - conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength - # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0 + conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 else: total_latent_num_frames = latent_tile_num_frames @@ -1055,14 +1051,14 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk ) latent_model_input = latent_model_input.to(prompt_embeds.dtype) - + # Create timestep tensor that has prod(latent_model_input.shape) elements if start_index == 0: timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) else: timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone() - timestep[:, -last_latent_chunk_num_tokens:] = 0.0 - timestep = timestep.float() + timestep[:, :last_latent_chunk_num_tokens] = 0.0 + timestep = timestep.float() # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() # if start_index > 0: # timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) @@ -1098,8 +1094,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi latent_chunk = denoised_latent_chunk else: latent_chunk = torch.cat( - [denoised_latent_chunk[:, :-last_latent_chunk_num_tokens], last_latent_chunk], - dim=1, + [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1 ) # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk) @@ -1134,7 +1129,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi if start_index == 0: first_tile_out_latents = latent_chunk.clone() else: - latent_chunk = latent_chunk[:, :, 1:-last_latent_tile_num_frames, :, :] + latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :] latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent( latent_chunk, first_tile_out_latents, adain_factor ) @@ -1145,10 +1140,10 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi # Combine samples t_minus_one = temporal_overlap - 1 parts = [ - latent_chunk[:, :, :-t_minus_one], - (1 - alpha) * latent_chunk[:, :, -t_minus_one:] - + alpha * tile_out_latents[:, :, :t_minus_one], - tile_out_latents[:, :, t_minus_one:], + tile_out_latents[:, :, :-t_minus_one], + alpha * tile_out_latents[:, :, -t_minus_one:] + + (1 - alpha) * latent_chunk[:, :, :t_minus_one], + latent_chunk[:, :, t_minus_one:], ] latent_chunk = torch.cat(parts, dim=2) @@ -1157,7 +1152,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi tile_weights = self._create_spatial_weights( tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap ) - final_latents[:, :, :, v_start:v_end, h_start:h_end] += tile_out_latents * tile_weights + final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights eps = 1e-8