mirror of https://github.com/huggingface/diffusers.git
refactor 2
@@ -979,6 +979,11 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
     last_latent_tile_num_frames = last_latent_chunk.shape[2]
     latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
     total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+
+    conditioning_mask = torch.zeros(
+        (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
+    )
+    conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
 else:
     total_latent_num_frames = latent_tile_num_frames
 
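Note on the block added above: the mask is per latent frame, with 1.0 marking the frames prepended from the previous tile (`last_latent_chunk`) and 0.0 marking the freshly generated ones. A minimal standalone sketch of the same construction, with made-up sizes (`overlap` and `new` are illustrative names, not from the diff):

# Sketch only: a per-frame conditioning mask for a chunk that prepends
# `overlap` frames carried over from the previous tile to `new` fresh frames.
import torch

batch_size, overlap, new = 2, 3, 8   # hypothetical sizes
total = overlap + new                # plays the role of total_latent_num_frames

mask = torch.zeros((batch_size, total), dtype=torch.float32)
mask[:, :overlap] = 1.0              # 1.0 = frame comes from the previous tile
print(mask[0])                       # tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])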
@@ -998,6 +1003,21 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
     device=device,
 )
 
+if start_index > 0:
+    conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
+
+video_ids = self._scale_video_ids(
+    video_ids,
+    scale_factor=self.vae_spatial_compression_ratio,
+    scale_factor_t=self.vae_temporal_compression_ratio,
+    frame_index=0,
+    device=device
+)
+video_ids = video_ids.float()
+video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
+if self.do_classifier_free_guidance:
+    video_ids = torch.cat([video_ids, video_ids], dim=0)
+
 # Set timesteps
 inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 sigmas = self.scheduler.sigmas
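Note on the `gather` added above: the per-frame mask is expanded to a per-token mask by indexing with `video_ids[:, 0]`, which, assuming the usual `[batch, coordinate, token]` layout with the temporal coordinate in row 0, holds each token's latent frame index. This also explains the ordering in the hunk: the gather runs while `video_ids` is still an integer tensor, before the `.float()` cast. A small sketch of the indexing, with assumed shapes:

# Sketch only: expanding a per-frame mask to a per-token mask via gather,
# assuming video_ids has shape [batch, 3, num_tokens] with row 0 = frame index.
import torch

frames, tokens_per_frame = 4, 6
frame_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])                        # [1, frames]
frame_index = torch.arange(frames).repeat_interleave(tokens_per_frame)   # [frames * tokens_per_frame]
video_ids = torch.stack(
    [frame_index, torch.zeros_like(frame_index), torch.zeros_like(frame_index)]
).unsqueeze(0)                                                           # [1, 3, 24]

token_mask = frame_mask.gather(1, video_ids[:, 0])                       # [1, 24]
print(token_mask.shape)                                                  # torch.Size([1, 24])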
@@ -1005,18 +1025,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
 self._num_timesteps = len(inner_timesteps)
 
-if start_index == 0:
-    video_ids = self._scale_video_ids(
-        video_ids,
-        scale_factor=self.vae_spatial_compression_ratio,
-        scale_factor_t=self.vae_temporal_compression_ratio,
-        frame_index=0,
-        device=device
-    )
-    video_ids = video_ids.float()
-    video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
-    if self.do_classifier_free_guidance:
-        video_ids = torch.cat([video_ids, video_ids], dim=0)
 
 with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
     for i, t in enumerate(inner_timesteps):
         if self.interrupt:
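The block deleted above is the first-tile (`start_index == 0`) copy of the `video_ids` preparation that the previous hunk now runs unconditionally. One part of it is worth spelling out: the classifier-free-guidance duplication, `torch.cat([video_ids, video_ids], dim=0)`, only doubles the id batch so it stays aligned with the doubled latent batch. Illustrative shapes only; the sizes below are made up:

# Sketch only: why the ids are duplicated under classifier-free guidance.
import torch

video_ids = torch.zeros(2, 3, 128)   # [batch, coords, tokens], hypothetical sizes
latents = torch.zeros(2, 128, 64)    # hypothetical latent token batch

latent_model_input = torch.cat([latents, latents], dim=0)   # CFG doubles the latent batch
video_ids = torch.cat([video_ids, video_ids], dim=0)        # ids doubled to stay aligned
print(latent_model_input.shape[0], video_ids.shape[0])      # 4 4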
@@ -1079,24 +1087,6 @@ class LTXConditionInfinitePipeline(DiffusionPipeline, FromSingleFileMixin, LTXVi
     )
     first_tile_out_latents = tile_out_latents.clone()
 else:
-    conditioning_mask = torch.zeros(
-        (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
-    )
-    conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
-    conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
-
-    video_ids = self._scale_video_ids(
-        video_ids,
-        scale_factor=self.vae_spatial_compression_ratio,
-        scale_factor_t=self.vae_temporal_compression_ratio,
-        frame_index=0,
-        device=device
-    )
-    video_ids = video_ids.float()
-    video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
-    if self.do_classifier_free_guidance:
-        video_ids = torch.cat([video_ids, video_ids], dim=0)
-
     with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
         for i, t in enumerate(inner_timesteps):
             if self.interrupt:
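Taken together with the previous hunk, this removes the second (else-branch) copy of the same preparation. After the refactor only the conditioning-mask construction stays conditional on the tile index, while `_scale_video_ids`, the `.float()` cast, the frame-rate conversion, and the CFG duplication run once on the shared path. The frame-rate line turns temporal ids into seconds; a rough sketch of that conversion under assumed values (`_scale_video_ids` is pipeline-internal, and the compression ratio of 8 is only an assumption, not taken from this diff):

# Sketch only: latent frame indices appear to be scaled to pixel-frame indices
# and then converted to seconds with 1 / frame_rate, mirroring the diff's
# video_ids[:, 0] * (1.0 / frame_rate) step.
import torch

vae_temporal_compression_ratio = 8   # assumed value for illustration
frame_rate = 25.0

latent_frame_ids = torch.arange(4)                           # latent-space frame indices 0..3
pixel_frame_ids = latent_frame_ids * vae_temporal_compression_ratio
timestamps = pixel_frame_ids.float() * (1.0 / frame_rate)    # seconds per token row
print(timestamps)                                            # tensor([0.0000, 0.3200, 0.6400, 0.9600])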