
this version kinda works but results still bad

commit 27a43451cb
parent 16b82a546f
Author: Aryan
Date: 2025-08-13 22:06:31 +02:00


@@ -396,7 +396,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
         negative_prompt_attention_mask=None,
         temporal_tile_size=None,
         temporal_overlap=None,
-        temporal_overlap_cond_strength=None,
         horizontal_tiles=None,
         vertical_tiles=None,
         spatial_overlap=None,
@@ -449,10 +448,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
             )
         if temporal_overlap < 16 or temporal_overlap > 80 or temporal_overlap % 8 != 0:
             raise ValueError(f"`temporal_overlap` must be in [16, 80] and divisible by 8 but is {temporal_overlap}.")
-        if not (0.0 <= temporal_overlap_cond_strength <= 1.0):
-            raise ValueError(
-                f"`temporal_overlap_cond_strength` must be in [0.0, 1.0] but is {temporal_overlap_cond_strength}."
-            )
         if not (1 <= horizontal_tiles <= 6):
             raise ValueError(f"`horizontal_tiles` must be between 1 and 6 but is {horizontal_tiles}.")
         if not (1 <= vertical_tiles <= 6):
@@ -573,7 +568,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
         end_idx = max(0, min(end_idx, num_frames - 1))
         if start_idx > end_idx:
             start_idx = min(start_idx, end_idx)
-        latents = latents[:, :, start_idx : end_idx + 1, :, :]
+        latents = latents[:, :, start_idx : end_idx + 1, :, :].clone()
         return latents

     @staticmethod
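
The `.clone()` matters because slicing a tensor returns a view that shares storage with `tile_out_latents`: any later in-place write to the returned chunk would silently mutate the cached tile outputs. A minimal sketch of the aliasing this guards against (shapes and names are illustrative, not the pipeline's):

import torch

tile_out_latents = torch.zeros(1, 4, 8, 2, 2)       # (B, C, F, H, W)
view = tile_out_latents[:, :, -3:-1, :, :]          # a view, shares memory
view += 1.0                                         # in-place write corrupts the source
assert tile_out_latents.abs().sum() > 0             # the cached latents changed

safe = tile_out_latents[:, :, -3:-1, :, :].clone()  # independent copy
safe -= 1.0                                         # source stays untouched
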
@@ -742,7 +737,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
         timesteps: List[int] = None,
         guidance_scale: float = 3,
         guidance_rescale: float = 0.0,
-        image_cond_noise_scale: float = 0.15,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -762,7 +756,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
         max_sequence_length: int = 256,
         temporal_tile_size: int = 80,
         temporal_overlap: int = 24,
-        temporal_overlap_cond_strength: float = 0.5,
         horizontal_tiles: int = 1,
         vertical_tiles: int = 1,
         spatial_overlap: int = 1,
@@ -846,8 +839,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                 region.
             temporal_overlap (`int`, defaults to `24`):
                 The overlap between the temporal tiles, in pixel frames.
-            temporal_overlap_cond_strength (`float`, defaults to `0.5`):
-                The strength of the conditioning on the latents from the previous temporal tile.
             horizontal_tiles (`int`, defaults to `1`):
                 Number of horizontal spatial tiles to use for the sampling.
             vertical_tiles (`int`, defaults to `1`):
@@ -878,7 +869,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             temporal_tile_size=temporal_tile_size,
             temporal_overlap=temporal_overlap,
-            temporal_overlap_cond_strength=temporal_overlap_cond_strength,
             horizontal_tiles=horizontal_tiles,
             vertical_tiles=vertical_tiles,
             spatial_overlap=spatial_overlap,
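
With `temporal_overlap_cond_strength` (and `image_cond_noise_scale`) removed end to end, a caller simply drops them. A hypothetical invocation, assuming `pipe` is an already-loaded `LTXConditionInfinitePipeline` (loading elided; argument values illustrative):

video = pipe(
    prompt="a timelapse of clouds rolling over a mountain ridge",
    num_frames=241,
    temporal_tile_size=80,
    temporal_overlap=24,  # overlap frames are now always fully conditioned
    horizontal_tiles=1,
    vertical_tiles=1,
    spatial_overlap=1,
).frames[0]
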
@@ -1076,8 +1066,8 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
             else:
                 last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
                 last_latent_tile_num_frames = last_latent_chunk.shape[2]
-                latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
-                total_latent_num_frames = latent_tile_num_frames + last_latent_tile_num_frames
+                latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
+                total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
             latent_chunk = self._pack_latents(
                 latent_chunk,
                 self.transformer_spatial_patch_size,
@@ -1094,10 +1084,10 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                 device=device,
             )
-            conditioning_mask = torch.ones(
+            conditioning_mask = torch.zeros(
                 (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
             )
-            conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
+            conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
             conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
             video_ids = self._scale_video_ids(
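
These two hunks have to agree: the previous tile's overlap latents are now prepended rather than appended, so the conditioning frames live in the first `last_latent_tile_num_frames` positions, which is exactly where the mask puts its ones. Note the flip in meaning as well: the old code started from `torch.ones` and assigned a fractional strength to the overlap; the new code starts from zeros and marks only the reused frames as fully conditioned. A toy sketch of the layout (shapes illustrative):

import torch

last_latent_chunk = torch.randn(1, 128, 3, 16, 16)  # overlap frames from previous tile
latent_chunk = torch.randn(1, 128, 10, 16, 16)      # frames to generate in this tile

latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
conditioning_mask = torch.zeros(1, latent_chunk.shape[2])
conditioning_mask[:, : last_latent_chunk.shape[2]] = 1.0
print(conditioning_mask)
# tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
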
@@ -1120,7 +1110,13 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                 self._current_timestep = t
                 latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+                conditioning_mask_model_input = (
+                    torch.cat([conditioning_mask, conditioning_mask])
+                    if self.do_classifier_free_guidance
+                    else conditioning_mask
+                )
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                 with self.transformer.cache_context("cond_uncond"):
                     noise_pred = self.transformer(
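
The added clamp turns the scalar scheduler timestep into a per-token one: conditioned tokens (mask 1.0) get timestep 0 and are treated as already clean, while tokens still to be generated keep the scheduler timestep. A small numerical check of the broadcast (values illustrative):

import torch

t = torch.tensor(800.0)                                   # current scheduler timestep
conditioning_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # (B, num_tokens)

timestep = t.expand(1).unsqueeze(-1).float()              # (B, 1)
timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)
print(timestep)  # tensor([[  0.,   0., 800., 800.]])
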
@@ -1175,7 +1171,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
                 self.transformer_temporal_patch_size,
             )
             # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-            latent_chunk = latent_chunk[:, :, 1:latent_tile_num_frames, :, :]
+            latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1:, :, :]
             latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
             alpha = torch.linspace(1, 0, temporal_overlap + 1, device=latent_chunk.device)[1:-1]
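
The slice now also accounts for the prepended conditioning frames: after unpacking, the first `last_latent_tile_num_frames` frames are the previous tile's overlap and are cut away along with the reinterpreted first frame. `alpha` then sets up a linear crossfade over the interior of the overlap (the endpoint weights 1 and 0 are sliced off, leaving `temporal_overlap - 1` values). A sketch of how such a crossfade is typically applied; the pipeline's actual blending step lies outside this hunk:

import torch

temporal_overlap = 24
alpha = torch.linspace(1, 0, temporal_overlap + 1)[1:-1]       # (23,) weights
alpha = alpha.view(1, 1, -1, 1, 1)                             # broadcast over (B, C, F, H, W)

prev_tail = torch.randn(1, 128, temporal_overlap - 1, 16, 16)  # end of previous tile
new_head = torch.randn(1, 128, temporal_overlap - 1, 16, 16)   # start of new tile
blended = alpha * prev_tail + (1 - alpha) * new_head           # fades old -> new
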