mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
this version kinda works but results still bad
This commit is contained in:
@@ -396,7 +396,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
negative_prompt_attention_mask=None,
|
||||
temporal_tile_size=None,
|
||||
temporal_overlap=None,
|
||||
temporal_overlap_cond_strength=None,
|
||||
horizontal_tiles=None,
|
||||
vertical_tiles=None,
|
||||
spatial_overlap=None,
|
||||
@@ -449,10 +448,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
)
|
||||
if temporal_overlap < 16 or temporal_overlap > 80 or temporal_overlap % 8 != 0:
|
||||
raise ValueError(f"`temporal_overlap` must be in [16, 80] and divisible by 8 but is {temporal_overlap}.")
|
||||
if not (0.0 <= temporal_overlap_cond_strength <= 1.0):
|
||||
raise ValueError(
|
||||
f"`temporal_overlap_cond_strength` must be in [0.0, 1.0] but is {temporal_overlap_cond_strength}."
|
||||
)
|
||||
if not (1 <= horizontal_tiles <= 6):
|
||||
raise ValueError(f"`horizontal_tiles` must be between 1 and 6 but is {horizontal_tiles}.")
|
||||
if not (1 <= vertical_tiles <= 6):
|
||||
@@ -573,7 +568,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
end_idx = max(0, min(end_idx, num_frames - 1))
|
||||
if start_idx > end_idx:
|
||||
start_idx = min(start_idx, end_idx)
|
||||
latents = latents[:, :, start_idx : end_idx + 1, :, :]
|
||||
latents = latents[:, :, start_idx : end_idx + 1, :, :].clone()
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
@@ -742,7 +737,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3,
|
||||
guidance_rescale: float = 0.0,
|
||||
image_cond_noise_scale: float = 0.15,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -762,7 +756,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
max_sequence_length: int = 256,
|
||||
temporal_tile_size: int = 80,
|
||||
temporal_overlap: int = 24,
|
||||
temporal_overlap_cond_strength: float = 0.5,
|
||||
horizontal_tiles: int = 1,
|
||||
vertical_tiles: int = 1,
|
||||
spatial_overlap: int = 1,
|
||||
@@ -846,8 +839,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
region.
|
||||
temporal_overlap (`int`, defaults to `24`):
|
||||
The overlap between the temporal tiles, in pixel frames.
|
||||
temporal_overlap_cond_strength (`float`, defaults to `0.5`):
|
||||
The strength of the conditioning on the latents from the previous temporal tile.
|
||||
horizontal_tiles (`int`, defaults to `1`):
|
||||
Number of horizontal spatial tiles to use for the sampling.
|
||||
vertical_tiles (`int`, defaults to `1`):
|
||||
@@ -878,7 +869,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
temporal_tile_size=temporal_tile_size,
|
||||
temporal_overlap=temporal_overlap,
|
||||
temporal_overlap_cond_strength=temporal_overlap_cond_strength,
|
||||
horizontal_tiles=horizontal_tiles,
|
||||
vertical_tiles=vertical_tiles,
|
||||
spatial_overlap=spatial_overlap,
|
||||
@@ -1076,8 +1066,8 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
else:
|
||||
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
|
||||
last_latent_tile_num_frames = last_latent_chunk.shape[2]
|
||||
latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
|
||||
total_latent_num_frames = latent_tile_num_frames + last_latent_tile_num_frames
|
||||
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
|
||||
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
|
||||
latent_chunk = self._pack_latents(
|
||||
latent_chunk,
|
||||
self.transformer_spatial_patch_size,
|
||||
@@ -1094,10 +1084,10 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
device=device,
|
||||
)
|
||||
|
||||
conditioning_mask = torch.ones(
|
||||
conditioning_mask = torch.zeros(
|
||||
(batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
|
||||
)
|
||||
conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
|
||||
conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
|
||||
conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
|
||||
|
||||
video_ids = self._scale_video_ids(
|
||||
@@ -1120,7 +1110,13 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
conditioning_mask_model_input = (
|
||||
torch.cat([conditioning_mask, conditioning_mask])
|
||||
if self.do_classifier_free_guidance
|
||||
else conditioning_mask
|
||||
)
|
||||
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
|
||||
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
|
||||
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
@@ -1175,7 +1171,7 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
# We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
|
||||
latent_chunk = latent_chunk[:, :, 1:latent_tile_num_frames, :, :]
|
||||
latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1:, :, :]
|
||||
latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
|
||||
|
||||
alpha = torch.linspace(1, 0, temporal_overlap + 1, device=latent_chunk.device)[1:-1]
|
||||
|
||||
Reference in New Issue
Block a user