1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

refactor 1

This commit is contained in:
Aryan
2025-08-14 01:12:13 +02:00
parent e981399b81
commit 47bf390bb4

View File

@@ -973,6 +973,30 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
):
latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
latent_tile_num_frames = latent_chunk.shape[2]
if start_index > 0:
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
last_latent_tile_num_frames = last_latent_chunk.shape[2]
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
else:
total_latent_num_frames = latent_tile_num_frames
latent_chunk = self._pack_latents(
latent_chunk,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
video_ids = self._prepare_video_ids(
batch_size,
total_latent_num_frames,
latent_tile_height,
latent_tile_width,
patch_size_t=self.transformer_temporal_patch_size,
patch_size=self.transformer_spatial_patch_size,
device=device,
)
# Set timesteps
inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -981,17 +1005,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
self._num_timesteps = len(inner_timesteps)
if start_index == 0:
latent_chunk = self._pack_latents(latent_chunk, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
video_ids = self._prepare_video_ids(
batch_size,
latent_tile_num_frames,
latent_tile_height,
latent_tile_width,
patch_size_t=self.transformer_temporal_patch_size,
patch_size=self.transformer_spatial_patch_size,
device=device,
)
video_ids = self._scale_video_ids(
video_ids,
scale_factor=self.vae_spatial_compression_ratio,
@@ -1066,26 +1079,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
)
first_tile_out_latents = tile_out_latents.clone()
else:
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
last_latent_tile_num_frames = last_latent_chunk.shape[2]
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
latent_chunk = self._pack_latents(
latent_chunk,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
video_ids = self._prepare_video_ids(
batch_size,
total_latent_num_frames,
latent_tile_height,
latent_tile_width,
patch_size_t=self.transformer_temporal_patch_size,
patch_size=self.transformer_spatial_patch_size,
device=device,
)
conditioning_mask = torch.zeros(
(batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
)