Get latent upsampler working with video latents
@@ -53,6 +53,7 @@ def parse_args():
    parser.add_argument("--dtype", type=str, default="bf16")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--vae_tiling", action="store_true")
    parser.add_argument("--use_video_latents", action="store_true")

    parser.add_argument(
        "--output_dir",
@@ -83,6 +84,10 @@ def main(args):

    image = load_image(args.image_path)

    first_stage_output_type = "pil"
    if args.use_video_latents:
        first_stage_output_type = "latent"

    video, audio = pipeline(
        image=image,
        prompt=args.prompt,
@@ -94,15 +99,39 @@ def main(args):
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=torch.Generator(device=args.device).manual_seed(args.seed),
        output_type="pil",
        output_type=first_stage_output_type,
        return_dict=False,
    )

    if args.use_video_latents:
        # Manually convert the audio latents to a waveform
        audio = audio.to(pipeline.audio_vae.dtype)
        audio = pipeline._denormalize_audio_latents(
            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
        )

        sampling_rate = pipeline.audio_sampling_rate
        hop_length = pipeline.audio_hop_length
        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
        duration_s = args.num_frames / args.frame_rate
        latents_per_second = (
            float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
        )
        audio_latent_frames = int(duration_s * latents_per_second)
        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)

        audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
        audio = pipeline.vocoder(audio)

        # Get some pipeline configs for upsampling
        spatial_patch_size = pipeline.transformer_spatial_patch_size
        temporal_patch_size = pipeline.transformer_temporal_patch_size

    # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
    #     args.model_id, revision=args.revision, torch_dtype=args.dtype,
    # )
    output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
    pipeline.to(device="cpu")
    del pipeline  # Otherwise there might be an OOM error?
    torch.cuda.empty_cache()
    gc.collect()
@@ -124,13 +153,21 @@ def main(args):
    if args.vae_tiling:
        upsample_pipeline.enable_vae_tiling()

    video = upsample_pipeline(
        video=video,
        height=args.height,
        width=args.width,
        output_type="np",
        return_dict=False,
    )[0]
    upsample_kwargs = {
        "height": args.height,
        "width": args.width,
        "output_type": "np",
        "return_dict": False,
    }
    if args.use_video_latents:
        upsample_kwargs["latents"] = video
        upsample_kwargs["num_frames"] = args.num_frames
        upsample_kwargs["spatial_patch_size"] = spatial_patch_size
        upsample_kwargs["temporal_patch_size"] = temporal_patch_size
    else:
        upsample_kwargs["video"] = video

    video = upsample_pipeline(**upsample_kwargs)[0]

    # Convert video to uint8 (but keep as NumPy array)
    video = (video * 255).round().astype("uint8")
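With output_type="np", diffusers video pipelines return frames as floats in [0, 1], so the uint8 conversion above is a plain rescale. A tiny sanity check (standalone NumPy, not part of the diff):

import numpy as np

frames = np.array([0.0, 0.5, 1.0], dtype=np.float32)   # stand-in for [0, 1] pixel values
print((frames * 255).round().astype("uint8"))           # [  0 128 255]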
@@ -67,12 +67,25 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        self,
        video: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        num_frames: int = 121,
        height: int = 512,
        width: int = 768,
        spatial_patch_size: int = 1,
        temporal_patch_size: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            if latents.ndim == 3:
                # Convert token seq [B, S, D] to latent video [B, C, F, H, W]
                latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
                latent_height = height // self.vae_spatial_compression_ratio
                latent_width = width // self.vae_spatial_compression_ratio
                latents = self._unpack_latents(
                    latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
                )
            return latents.to(device=device, dtype=dtype)

        video = video.to(device=device, dtype=self.vae.dtype)
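The latent grid that a packed [B, S, D] sequence is unpacked into follows directly from the VAE compression ratios used above. A worked example with assumed ratios (a temporal compression of 8 and a spatial compression of 32 are placeholders; the real values come from the pipeline's VAE config):

# Assumed compression ratios, for illustration only.
vae_temporal_compression_ratio = 8
vae_spatial_compression_ratio = 32

num_frames, height, width = 121, 512, 768
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 16
latent_height = height // vae_spatial_compression_ratio                     # 16
latent_width = width // vae_spatial_compression_ratio                       # 24
print(latent_num_frames, latent_height, latent_width)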
@@ -175,6 +188,19 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
    def _unpack_latents(
        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
    ) -> torch.Tensor:
        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
        # what happens in the `_pack_latents` method.
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
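To check that the reshape/permute in `_unpack_latents` really inverts packing, here is a small standalone round trip. The `pack` helper is a sketch written for this check, mirroring the packing convention implied by the unpack above; it is not code taken from the pipeline:

import torch


def pack(latents, patch_size=1, patch_size_t=1):
    # Sketch: [B, C, F, H, W] -> [B, S, D], the assumed inverse of `_unpack_latents`
    b, c, f, h, w = latents.shape
    latents = latents.reshape(
        b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)


def unpack(latents, num_frames, height, width, patch_size=1, patch_size_t=1):
    # Same operations as `_unpack_latents` in the diff above
    b = latents.size(0)
    latents = latents.reshape(b, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)


video_latents = torch.randn(1, 128, 16, 16, 24)                # [B, C, F, H, W]
tokens = pack(video_latents)                                    # [1, 6144, 128]
restored = unpack(tokens, num_frames=16, height=16, width=24)
assert torch.equal(restored, video_latents)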
@@ -246,6 +272,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        video: Optional[List[PipelineImageInput]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        spatial_patch_size: int = 1,
        temporal_patch_size: int = 1,
        latents: Optional[torch.Tensor] = None,
        decode_timestep: Union[float, List[float]] = 0.0,
        decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -286,6 +315,11 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        latents = self.prepare_latents(
            video=video,
            batch_size=batch_size,
            num_frames=num_frames,
            height=height,
            width=width,
            spatial_patch_size=spatial_patch_size,
            temporal_patch_size=temporal_patch_size,
            dtype=torch.float32,
            device=device,
            generator=generator,