# Source: ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)
import torch
import os
import gc
from PIL import Image
import numpy as np
from ..latent_preview import prepare_callback
from ..wanvideo.schedulers import get_scheduler
from .multitalk import timestep_transform, add_noise
from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap
from comfy.utils import load_torch_file
from ..nodes_model_loading import load_weights
from ..HuMo.nodes import get_audio_emb_window
import comfy.model_management as mm
from tqdm import tqdm
import copy

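# Wan VAE compression factors: latents are downsampled 4x in time and 8x spatially,
# so pixel sizes below are recovered by multiplying latent H/W by vae_upscale_factor.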
VAE_STRIDE = (4, 8, 8)
PATCH_SIZE = (1, 2, 2)
vae_upscale_factor = 8
script_directory = os.path.dirname(os.path.abspath(__file__))

device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

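# Sliding-window sampling loop for MultiTalk / InfiniteTalk audio-driven video generation:
# the clip is generated window by window, and each new window re-uses the last `motion_frame`
# frames of the previous one as conditioning so the segments join smoothly.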
def multitalk_loop(self, **kwargs):
    # Unpack kwargs into local variables
    (latent, total_steps, steps, start_step, end_step, shift, cfg, denoise_strength,
    sigmas, weight_dtype, transformer, patcher, block_swap_args, model, vae, dtype,
    scheduler, scheduler_step_args, text_embeds, image_embeds, multitalk_embeds,
    multitalk_audio_embeds, unianim_data, dwpose_data, unianimate_poses, uni3c_embeds,
    humo_image_cond, humo_image_cond_neg, humo_audio, humo_reference_count,
    add_noise_to_samples, audio_stride, use_tsr, tsr_k, tsr_sigma, fantasy_portrait_input,
    noise, timesteps, force_offload, add_cond, control_latents, audio_proj,
    control_camera_latents, samples, masks, seed_g, gguf_reader, predict_func
    ) = (kwargs.get(k) for k in (
    'latent', 'total_steps', 'steps', 'start_step', 'end_step', 'shift', 'cfg',
    'denoise_strength', 'sigmas', 'weight_dtype', 'transformer', 'patcher',
    'block_swap_args', 'model', 'vae', 'dtype', 'scheduler', 'scheduler_step_args',
    'text_embeds', 'image_embeds', 'multitalk_embeds', 'multitalk_audio_embeds',
    'unianim_data', 'dwpose_data', 'unianimate_poses', 'uni3c_embeds',
    'humo_image_cond', 'humo_image_cond_neg', 'humo_audio', 'humo_reference_count',
    'add_noise_to_samples', 'audio_stride', 'use_tsr', 'tsr_k', 'tsr_sigma',
    'fantasy_portrait_input', 'noise', 'timesteps', 'force_offload', 'add_cond',
    'control_latents', 'audio_proj', 'control_camera_latents', 'samples', 'masks',
    'seed_g', 'gguf_reader', 'predict_with_cfg'
    ))
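    # Keys are supplied by the caller; anything missing resolves to None via kwargs.get.
    # Note: the local name predict_func is bound to the 'predict_with_cfg' kwarg.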

    mode = image_embeds.get("multitalk_mode", "multitalk")
    if mode == "auto":
        mode = transformer.multitalk_model_type.lower()
    log.info(f"Multitalk mode: {mode}")
    cond_frame = None
    offload = image_embeds.get("force_offload", False)
    offloaded = False
    tiled_vae = image_embeds.get("tiled_vae", False)
    frame_num = clip_length = image_embeds.get("frame_window_size", 81)

    clip_embeds = image_embeds.get("clip_context", None)
    if clip_embeds is not None:
        clip_embeds = clip_embeds.to(dtype)
    colormatch = image_embeds.get("colormatch", "disabled")
    motion_frame = image_embeds.get("motion_frame", 25)
    target_w = image_embeds.get("target_w", None)
    target_h = image_embeds.get("target_h", None)
    original_images = cond_image = image_embeds.get("multitalk_start_image", None)
    if original_images is None:
        original_images = torch.zeros([noise.shape[0], 1, target_h, target_w], device=device)

    output_path = image_embeds.get("output_path", "")
    img_counter = 0

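    # With two audio tracks and no explicit reference masks, assume two side-by-side
    # speakers: speaker 1 gets the left half of the frame, speaker 2 the right half,
    # and everything outside those regions is treated as background.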
    if len(multitalk_embeds['audio_features'])==2 and (multitalk_embeds['ref_target_masks'] is None):
        face_scale = 0.1
        x_min, x_max = int(target_h * face_scale), int(target_h * (1 - face_scale))
        lefty_min, lefty_max = int((target_w//2) * face_scale), int((target_w//2) * (1 - face_scale))
        righty_min, righty_max = int((target_w//2) * face_scale + (target_w//2)), int((target_w//2) * (1 - face_scale) + (target_w//2))
        human_mask1, human_mask2 = (torch.zeros([target_h, target_w]) for _ in range(2))
        human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
        human_mask2[x_min:x_max, righty_min:righty_max] = 1
        background_mask = torch.where((human_mask1 + human_mask2) > 0, torch.tensor(0), torch.tensor(1))
        human_masks = [human_mask1, human_mask2, background_mask]
        ref_target_masks = torch.stack(human_masks, dim=0)
        multitalk_embeds['ref_target_masks'] = ref_target_masks

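    # Sliding-window state: each window covers `frame_window_size` frames and overlaps
    # the previous one by `motion_frame` frames; `indices` selects the 5 audio embeddings
    # centered on each frame (offsets -2..+2).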
    gen_video_list = []
    is_first_clip = True
    arrive_last_frame = False
    cur_motion_frames_num = 1
    audio_start_idx = iteration_count = step_iteration_count = 0
    audio_end_idx = (audio_start_idx + clip_length) * audio_stride
    indices = (torch.arange(4 + 1) - 2) * 1
    current_condframe_index = 0

    audio_embedding = multitalk_audio_embeds
    human_num = len(audio_embedding)
    audio_embs = None
    cond_frame = None

    uni3c_data = None
    if uni3c_embeds is not None:
        transformer.controlnet = uni3c_embeds["controlnet"]
        uni3c_data = uni3c_embeds.copy()

    encoded_silence = None

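    # A pre-encoded silence embedding (encoded_silence.safetensors, expected next to this
    # file) is preferred for padding audio past its end; otherwise the tail of the audio
    # embedding is mirrored as a fallback (see the padding logic at the end of the loop).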
    try:
        silence_path = os.path.join(script_directory, "encoded_silence.safetensors")
        encoded_silence = load_torch_file(silence_path)["audio_emb"].to(dtype)
    except:
        log.warning("No encoded silence file found, padding with end of audio embedding instead.")

    total_frames = len(audio_embedding[0])
    estimated_iterations = total_frames // (frame_num - motion_frame) + 1
    callback = prepare_callback(patcher, estimated_iterations)

    if frame_num >= total_frames:
        arrive_last_frame = True
        estimated_iterations = 1

    log.info(f"Sampling {total_frames} frames in {estimated_iterations} windows, at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps")

    while True: # start video generation iteratively
        self.cache_state = [None, None]

        cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num-1) // 4)
        if mode == "infinitetalk":
            cond_image = original_images[:, :, current_condframe_index:current_condframe_index+1] if cond_image is not None else None
        if multitalk_embeds is not None:
            audio_embs = []
            # gather the audio embeddings for the current window
            for human_idx in range(human_num):
                center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + indices.unsqueeze(0)
                center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0]-1)
                audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device)
                audio_embs.append(audio_emb)
            audio_embs = torch.concat(audio_embs, dim=0).to(dtype)

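        # Fresh noise for this window: [C=16, T_latent, H/8, W/8], drawn on CPU with the
        # provided generator so results are reproducible for a given seed.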
        h, w = (cond_image.shape[-2], cond_image.shape[-1]) if cond_image is not None else (target_h, target_w)
        lat_h, lat_w = h // VAE_STRIDE[1], w // VAE_STRIDE[2]
        latent_frame_num = (frame_num - 1) // 4 + 1

        noise = torch.randn(
            16, latent_frame_num,
            lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device)

        # Calculate the correct latent slice based on current iteration
        if is_first_clip:
            latent_start_idx = 0
            latent_end_idx = noise.shape[1]
        else:
            new_frames_per_iteration = frame_num - motion_frame
            new_latent_frames_per_iteration = ((new_frames_per_iteration - 1) // 4 + 1)
            latent_start_idx = iteration_count * new_latent_frames_per_iteration
            latent_end_idx = latent_start_idx + noise.shape[1]

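        # Optional input latents (e.g. video2video / continuation): slice them to this
        # window, pad by repeating the last frame if too short, then either use them
        # directly or blend them with noise according to the first timestep.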
        if samples is not None:
            noise_mask = samples.get("noise_mask", None)
            input_samples = samples["samples"]
            if input_samples is not None:
                input_samples = input_samples.squeeze(0).to(noise)
                # Check if we have enough frames in input_samples
                if latent_end_idx > input_samples.shape[1]:
                    # We need more frames than available - pad the input_samples at the end
                    pad_length = latent_end_idx - input_samples.shape[1]
                    last_frame = input_samples[:, -1:].repeat(1, pad_length, 1, 1)
                    input_samples = torch.cat([input_samples, last_frame], dim=1)
                input_samples = input_samples[:, latent_start_idx:latent_end_idx]
                if noise_mask is not None:
                    original_image = input_samples.to(device)

                assert input_samples.shape[1] == noise.shape[1], f"Slice mismatch: {input_samples.shape[1]} vs {noise.shape[1]}"

                if add_noise_to_samples:
                    latent_timestep = timesteps[0]
                    noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
                else:
                    noise = input_samples

            # diff diff prep
            if noise_mask is not None:
                if len(noise_mask.shape) == 4:
                    noise_mask = noise_mask.squeeze(1)
                if audio_end_idx > noise_mask.shape[0]:
                    noise_mask = noise_mask.repeat(audio_end_idx // noise_mask.shape[0], 1, 1)
                noise_mask = noise_mask[audio_start_idx:audio_end_idx]
                noise_mask = torch.nn.functional.interpolate(
                    noise_mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W]
                    size=(noise.shape[1], noise.shape[2], noise.shape[3]),
                    mode='trilinear',
                    align_corners=False
                ).repeat(1, noise.shape[0], 1, 1, 1)

                thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps)
                thresholds = thresholds.reshape(-1, 1, 1, 1, 1).to(device)
                masks = (1-noise_mask.repeat(len(timesteps), 1, 1, 1, 1).to(device)) > thresholds

        # zero padding and vae encode for img cond
        if cond_image is not None or cond_frame is not None:
            cond_ = cond_image if (is_first_clip or humo_image_cond is None) else cond_frame
            cond_frame_num = cond_.shape[2]
            video_frames = torch.zeros(1, 3, frame_num-cond_frame_num, target_h, target_w, device=device, dtype=vae.dtype)
            padding_frames_pixels_values = torch.concat([cond_.to(device, vae.dtype), video_frames], dim=2)

            # encode
            vae.to(device)
            y = vae.encode(padding_frames_pixels_values, device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]

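            # MultiTalk takes its motion-frame latents from the padded conditioning encode
            # above; InfiniteTalk encodes the conditioning frame(s) separately.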
            if mode == "multitalk":
                latent_motion_frames = y[:, :cur_motion_frames_latent_num] # C T H W
            else:
                cond_ = cond_image if is_first_clip else cond_frame
                latent_motion_frames = vae.encode(cond_.to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]

            vae.to(offload_device)

            #motion_frame_index = cur_motion_frames_latent_num if mode == "infinitetalk" else 1
            msk = torch.zeros(4, latent_frame_num, lat_h, lat_w, device=device, dtype=dtype)
            msk[:, :1] = 1
            y = torch.cat([msk, y]) # 4+C T H W
            mm.soft_empty_cache()
        else:
            y = None
            latent_motion_frames = noise[:, :1]

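        # HuMo conditioning: slice the image conditioning and audio embedding to this
        # window, keeping the trailing reference frames and zeroing their audio.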
        partial_humo_cond_input = partial_humo_cond_neg_input = partial_humo_audio = partial_humo_audio_neg = None
        if humo_image_cond is not None:
            partial_humo_cond_input = humo_image_cond[:, :latent_frame_num]
            partial_humo_cond_neg_input = humo_image_cond_neg[:, :latent_frame_num]
            if y is not None:
                partial_humo_cond_input[:, :1] = y[:, :1]
            if humo_reference_count > 0:
                partial_humo_cond_input[:, -humo_reference_count:] = humo_image_cond[:, -humo_reference_count:]
                partial_humo_cond_neg_input[:, -humo_reference_count:] = humo_image_cond_neg[:, -humo_reference_count:]

        if humo_audio is not None:
            if is_first_clip:
                audio_embs = None

            partial_humo_audio, _ = get_audio_emb_window(humo_audio, frame_num, frame0_idx=audio_start_idx)
            #zero_audio_pad = torch.zeros(humo_reference_count, *partial_humo_audio.shape[1:], device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)
            partial_humo_audio[-humo_reference_count:] = 0
            partial_humo_audio_neg = torch.zeros_like(partial_humo_audio, device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)

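        # Timestep schedule: the "multitalk" option builds its own 1000->0 schedule and
        # warps it with timestep_transform(shift=...); any other scheduler comes from
        # get_scheduler (or is passed in pre-built) with a terminal zero timestep appended.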
        if scheduler == "multitalk":
            timesteps = list(np.linspace(1000, 1, steps, dtype=np.float32))
            timesteps.append(0.)
            timesteps = [torch.tensor([t], device=device) for t in timesteps]
            timesteps = [timestep_transform(t, shift=shift, num_timesteps=1000) for t in timesteps]
        else:
            if isinstance(scheduler, dict):
                sample_scheduler = copy.deepcopy(scheduler["sample_scheduler"])
                timesteps = scheduler["timesteps"]
            else:
                sample_scheduler, timesteps,_,_ = get_scheduler(scheduler, total_steps, start_step, end_step, shift, device, transformer.dim, denoise_strength, sigmas=sigmas)
            timesteps = [torch.tensor([float(t)], device=device) for t in timesteps] + [torch.tensor([0.], device=device)]

        # sample videos
        latent = noise

        # injecting motion frames
        if not is_first_clip and mode == "multitalk":
            latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
            motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
            add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[0])
            latent[:, :add_latent.shape[1]] = add_latent

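        # If the transformer was offloaded after the previous window, reload its weights
        # (regular or GGUF path) and re-initialize block swapping before sampling again.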
        if offloaded:
            # Load weights
            if transformer.patched_linear and gguf_reader is None:
                load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
            elif gguf_reader is not None: #handle GGUF
                load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
            #blockswap init
            init_blockswap(transformer, block_swap_args, model)

        # Use the appropriate prompt for this section
        if len(text_embeds["prompt_embeds"]) > 1:
            prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
            positive = [text_embeds["prompt_embeds"][prompt_index]]
            log.info(f"Using prompt index: {prompt_index}")
        else:
            positive = text_embeds["prompt_embeds"]

        # uni3c slices
        if uni3c_embeds is not None:
            vae.to(device)
            # Pad original_images if needed
            num_frames = original_images.shape[2]
            if audio_end_idx > num_frames:
                pad_len = audio_end_idx - num_frames
                last_frame = original_images[:, :, -1:].repeat(1, 1, pad_len, 1, 1)
                padded_images = torch.cat([original_images, last_frame], dim=2)
            else:
                padded_images = original_images
            render_latent = vae.encode(
                padded_images[:, :, audio_start_idx:audio_end_idx].to(device, vae.dtype),
                device=device, tiled=tiled_vae
            ).to(dtype)

            vae.to(offload_device)
            uni3c_data['render_latent'] = render_latent

        # unianimate slices
        partial_unianim_data = None
        if unianim_data is not None:
            partial_dwpose = dwpose_data[:, :, latent_start_idx:latent_end_idx]
            partial_unianim_data = {
                "dwpose": partial_dwpose,
                "random_ref": unianim_data["random_ref"],
                "strength": unianimate_poses["strength"],
                "start_percent": unianimate_poses["start_percent"],
                "end_percent": unianimate_poses["end_percent"]
            }

        # fantasy portrait slices
        partial_fantasy_portrait_input = None
        if fantasy_portrait_input is not None:
            adapter_proj = fantasy_portrait_input["adapter_proj"]
            if latent_end_idx > adapter_proj.shape[1]:
                pad_len = latent_end_idx - adapter_proj.shape[1]
                last_frame = adapter_proj[:, -1:, :, :].repeat(1, pad_len, 1, 1)
                padded_proj = torch.cat([adapter_proj, last_frame], dim=1)
            else:
                padded_proj = adapter_proj
            partial_fantasy_portrait_input = fantasy_portrait_input.copy()
            partial_fantasy_portrait_input["adapter_proj"] = padded_proj[:, latent_start_idx:latent_end_idx]

        mm.soft_empty_cache()
        gc.collect()
        # sampling loop
        sampling_pbar = tqdm(total=len(timesteps)-1, desc=f"Sampling audio indices {audio_start_idx}-{audio_end_idx}", position=0, leave=True)
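        # Denoising loop for this window: predict noise (with CFG), step the scheduler
        # (or take a plain Euler step for the "multitalk" schedule), then re-apply the
        # motion-frame latents and any differential-diffusion mask.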
        for i in range(len(timesteps)-1):
            timestep = timesteps[i]
            latent_model_input = latent.to(device)
            if mode == "infinitetalk":
                if humo_image_cond is None or not is_first_clip:
                    latent_model_input[:, :cur_motion_frames_latent_num] = latent_motion_frames

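            # predict_func is the sampler's predict_with_cfg callable; it runs the
            # transformer for the current window and returns the noise prediction
            # together with the updated cache state.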
            noise_pred, _, self.cache_state = predict_func(
                latent_model_input, cfg[min(i, len(timesteps)-1)], positive, text_embeds["negative_prompt_embeds"],
                timestep, i, y, clip_embeds, control_latents, None, partial_unianim_data, audio_proj, control_camera_latents, add_cond,
                cache_state=self.cache_state, multitalk_audio_embeds=audio_embs, fantasy_portrait_input=partial_fantasy_portrait_input,
                humo_image_cond=partial_humo_cond_input, humo_image_cond_neg=partial_humo_cond_neg_input, humo_audio=partial_humo_audio, humo_audio_neg=partial_humo_audio_neg,
                uni3c_data = uni3c_data)

            if callback is not None:
                callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * timestep.to(device) / 1000).detach().permute(1,0,2,3)
                callback(step_iteration_count, callback_latent, None, estimated_iterations*(len(timesteps)-1))
                del callback_latent

            sampling_pbar.update(1)
            step_iteration_count += 1

            # update latent
            if use_tsr:
                noise_pred = temporal_score_rescaling(noise_pred, latent, timestep, tsr_k, tsr_sigma)
            if scheduler == "multitalk":
                noise_pred = -noise_pred
                dt = (timesteps[i] - timesteps[i + 1]) / 1000
                latent = latent + noise_pred * dt[:, None, None, None]
            else:
                latent = sample_scheduler.step(noise_pred.unsqueeze(0), timestep, latent.unsqueeze(0).to(noise_pred.device), **scheduler_step_args)[0].squeeze(0)
            del noise_pred, latent_model_input, timestep

            # differential diffusion inpaint
            if masks is not None:
                if i < len(timesteps) - 1:
                    image_latent = add_noise(original_image.to(device), noise.to(device), timesteps[i+1])
                    mask = masks[i].to(latent)
                    latent = image_latent * mask + latent * (1-mask)

            # injecting motion frames
            if not is_first_clip and mode == "multitalk":
                latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
                motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
                add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[i+1])
                latent[:, :add_latent.shape[1]] = add_latent
            else:
                if humo_image_cond is None or not is_first_clip:
                    latent[:, :cur_motion_frames_latent_num] = latent_motion_frames

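        # Window finished: free the transformer if requested, drop HuMo reference frames
        # from the latent, and decode this window to pixel space.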
        del noise, latent_motion_frames
        if offload:
            offload_transformer(transformer, remove_lora=False)
            offloaded = True
        if humo_image_cond is not None and humo_reference_count > 0:
            latent = latent[:,:-humo_reference_count]
        vae.to(device)
        videos = vae.decode(latent.unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu()

        vae.to(offload_device)

        sampling_pbar.close()

        # optional color correction (less relevant for InfiniteTalk)
        if colormatch != "disabled":
            videos = videos.permute(1, 2, 3, 0).float().numpy()
            from color_matcher import ColorMatcher
            cm = ColorMatcher()
            cm_result_list = []
            for img in videos:
                if mode == "multitalk":
                    cm_result = cm.transfer(src=img, ref=original_images[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
                else:
                    cm_result = cm.transfer(src=img, ref=cond_image[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
                cm_result_list.append(torch.from_numpy(cm_result).to(vae.dtype))

            videos = torch.stack(cm_result_list, dim=0).permute(3, 0, 1, 2)

        # optionally save generated samples to disk
        if output_path:
            video_np = videos.clamp(-1.0, 1.0).add(1.0).div(2.0).mul(255).cpu().float().numpy().transpose(1, 2, 3, 0).astype('uint8')
            num_frames_to_save = video_np.shape[0] if is_first_clip else video_np.shape[0] - cur_motion_frames_num
            log.info(f"Saving {num_frames_to_save} generated frames to {output_path}")
            start_idx = 0 if is_first_clip else cur_motion_frames_num
            for i in range(start_idx, video_np.shape[0]):
                im = Image.fromarray(video_np[i])
                im.save(os.path.join(output_path, f"frame_{img_counter:05d}.png"))
                img_counter += 1
        else:
            gen_video_list.append(videos if is_first_clip else videos[:, cur_motion_frames_num:])

        current_condframe_index += 1
        iteration_count += 1

        # decide whether we are done
        if arrive_last_frame:
            break

        # update next condition frames
        is_first_clip = False
        cur_motion_frames_num = motion_frame

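        # Carry the last `motion_frame` decoded frames into the next window as conditioning
        # (cond_frame for InfiniteTalk, cond_image for MultiTalk).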
        cond_ = videos[:, -cur_motion_frames_num:].unsqueeze(0)
        if mode == "infinitetalk":
            cond_frame = cond_
        else:
            cond_image = cond_

        del videos, latent

        # Repeat audio emb
        if multitalk_embeds is not None:
            audio_start_idx += (frame_num - cur_motion_frames_num - humo_reference_count)
            audio_end_idx = audio_start_idx + clip_length
            if audio_end_idx >= len(audio_embedding[0]):
                arrive_last_frame = True
                miss_lengths = []
                source_frames = []
                for human_inx in range(human_num):
                    source_frame = len(audio_embedding[human_inx])
                    source_frames.append(source_frame)
                    if audio_end_idx >= len(audio_embedding[human_inx]):
                        log.warning(f"Audio embedding for subject {human_inx} not long enough: {len(audio_embedding[human_inx])}, need {audio_end_idx}, padding...")
                        miss_length = audio_end_idx - len(audio_embedding[human_inx]) + 3
                        log.warning(f"Padding length: {miss_length}")
                        if encoded_silence is not None:
                            add_audio_emb = encoded_silence[-1*miss_length:]
                        else:
                            add_audio_emb = torch.flip(audio_embedding[human_inx][-1*miss_length:], dims=[0])
                        audio_embedding[human_inx] = torch.cat([audio_embedding[human_inx], add_audio_emb.to(device, dtype)], dim=0)
                        miss_lengths.append(miss_length)
                    else:
                        miss_lengths.append(0)
        if mode == "infinitetalk" and current_condframe_index >= original_images.shape[2]:
            last_frame = original_images[:, :, -1:, :, :]
            miss_length = 1
            original_images = torch.cat([original_images, last_frame.repeat(1, 1, miss_length, 1, 1)], dim=2)

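    # All windows generated: either stitch the per-window results together, or return a
    # dummy tensor when frames were already streamed to disk via output_path.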
    if not output_path:
        gen_video_samples = torch.cat(gen_video_list, dim=1)
    else:
        gen_video_samples = torch.zeros(3, 1, 64, 64) # dummy output

    if force_offload:
        if not model["auto_cpu_offload"]:
            offload_transformer(transformer)
        try:
            print_memory(device)
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass
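
    # Decoded video is returned as [T, H, W, C] (permuted from the VAE's [C, T, H, W]).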
    return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},