1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Try generating in reverse, as appears to be done in the original codebase

This commit is contained in:
Aryan
2025-08-14 03:36:10 +02:00
parent f2264e813f
commit 0e97669edb

View File

@@ -974,9 +974,11 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
latent_tile_num_frames = latent_chunk.shape[2]
if start_index > 0:
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
# last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
last_latent_chunk = self._select_latents(tile_out_latents, 0, temporal_overlap - 1)
last_latent_chunk = torch.flip(last_latent_chunk, dims=[2])
last_latent_tile_num_frames = last_latent_chunk.shape[2]
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
last_latent_chunk = self._pack_latents(
last_latent_chunk,
@@ -993,7 +995,9 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
device=device,
)
# conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
# conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
# conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0
else:
total_latent_num_frames = latent_tile_num_frames
@@ -1051,14 +1055,14 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# Create timestep tensor that has prod(latent_model_input.shape) elements
if start_index == 0:
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
else:
timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
timestep[:, :last_latent_chunk_num_tokens] = 0.0
timestep[:, -last_latent_chunk_num_tokens:] = 0.0
timestep = timestep.float()
# timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
# if start_index > 0:
# timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
@@ -1094,7 +1098,8 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
latent_chunk = denoised_latent_chunk
else:
latent_chunk = torch.cat(
[last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
[denoised_latent_chunk[:, :-last_latent_chunk_num_tokens], last_latent_chunk],
dim=1,
)
# tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
# latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
@@ -1129,7 +1134,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
if start_index == 0:
first_tile_out_latents = latent_chunk.clone()
else:
latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
latent_chunk = latent_chunk[:, :, 1:-last_latent_tile_num_frames, :, :]
latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
latent_chunk, first_tile_out_latents, adain_factor
)
@@ -1140,10 +1145,10 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
# Combine samples
t_minus_one = temporal_overlap - 1
parts = [
tile_out_latents[:, :, :-t_minus_one],
alpha * tile_out_latents[:, :, -t_minus_one:]
+ (1 - alpha) * latent_chunk[:, :, :t_minus_one],
latent_chunk[:, :, t_minus_one:],
latent_chunk[:, :, :-t_minus_one],
(1 - alpha) * latent_chunk[:, :, -t_minus_one:]
+ alpha * tile_out_latents[:, :, :t_minus_one],
tile_out_latents[:, :, t_minus_one:],
]
latent_chunk = torch.cat(parts, dim=2)
@@ -1152,7 +1157,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
tile_weights = self._create_spatial_weights(
tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
)
final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
final_latents[:, :, :, v_start:v_end, h_start:h_end] += tile_out_latents * tile_weights
weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
eps = 1e-8