mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
refactor 1
This commit is contained in:
@@ -973,6 +973,30 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
):
|
||||
latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
|
||||
latent_tile_num_frames = latent_chunk.shape[2]
|
||||
|
||||
if start_index > 0:
|
||||
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
|
||||
last_latent_tile_num_frames = last_latent_chunk.shape[2]
|
||||
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
|
||||
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
|
||||
else:
|
||||
total_latent_num_frames = latent_tile_num_frames
|
||||
|
||||
latent_chunk = self._pack_latents(
|
||||
latent_chunk,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
total_latent_num_frames,
|
||||
latent_tile_height,
|
||||
latent_tile_width,
|
||||
patch_size_t=self.transformer_temporal_patch_size,
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set timesteps
|
||||
inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
@@ -981,17 +1005,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
self._num_timesteps = len(inner_timesteps)
|
||||
|
||||
if start_index == 0:
|
||||
latent_chunk = self._pack_latents(latent_chunk, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
|
||||
|
||||
video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
latent_tile_num_frames,
|
||||
latent_tile_height,
|
||||
latent_tile_width,
|
||||
patch_size_t=self.transformer_temporal_patch_size,
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
device=device,
|
||||
)
|
||||
video_ids = self._scale_video_ids(
|
||||
video_ids,
|
||||
scale_factor=self.vae_spatial_compression_ratio,
|
||||
@@ -1066,26 +1079,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
|
||||
)
|
||||
first_tile_out_latents = tile_out_latents.clone()
|
||||
else:
|
||||
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
|
||||
last_latent_tile_num_frames = last_latent_chunk.shape[2]
|
||||
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
|
||||
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
|
||||
latent_chunk = self._pack_latents(
|
||||
latent_chunk,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
total_latent_num_frames,
|
||||
latent_tile_height,
|
||||
latent_tile_width,
|
||||
patch_size_t=self.transformer_temporal_patch_size,
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
conditioning_mask = torch.zeros(
|
||||
(batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user