mirror of
https://github.com/Wan-Video/Wan2.2.git
synced 2026-01-27 10:22:46 +03:00
239 lines
7.1 KiB
Python
239 lines
7.1 KiB
Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
import argparse
|
|
import binascii
|
|
import logging
|
|
import os
|
|
import os.path as osp
|
|
import shutil
|
|
import subprocess
|
|
|
|
import imageio
|
|
import torch
|
|
import torchvision
|
|
|
|
# Names exported when this module is star-imported.
__all__ = ['save_video', 'save_image', 'str2bool']
|
|
|
|
|
|
def rand_name(length=8, suffix=''):
    """Build a random hexadecimal file name.

    Args:
        length: Number of random bytes read from ``os.urandom``; the hex
            stem is twice this many characters long.
        suffix: Optional extension appended to the stem. A leading dot is
            added when it is missing.

    Returns:
        str: The hex stem followed by the normalized suffix, if any.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
|
|
|
|
|
|
def merge_video_audio(video_path: str, audio_path: str):
    """
    Mux an audio track into an existing video file by running ffmpeg.

    The video stream is copied unchanged, the audio is encoded as AAC, and the
    output is cut to the shorter of the two inputs. On success the merged
    result replaces the file at ``video_path``. A failure of the merge step is
    logged rather than raised.

    Parameters:
        video_path (str): Path to the original video file
        audio_path (str): Path to the audio file

    Raises:
        FileNotFoundError: If either input path does not exist.
    """
    # set logging
    logging.basicConfig(level=logging.INFO)

    # both inputs must exist before ffmpeg is invoked (video is checked first)
    for kind, path in (("video", video_path), ("audio", audio_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{kind} file {path} does not exist")

    # ffmpeg writes next to the original; the result is moved over it afterwards
    stem, extension = os.path.splitext(video_path)
    temp_output = f"{stem}_temp{extension}"

    try:
        ffmpeg_args = ['ffmpeg', '-y']  # -y: overwrite an existing output
        ffmpeg_args += ['-i', video_path, '-i', audio_path]
        ffmpeg_args += ['-c:v', 'copy']  # copy video stream
        ffmpeg_args += ['-c:a', 'aac', '-b:a', '192k']  # AAC audio at 192k
        ffmpeg_args += ['-map', '0:v:0']  # first video stream of input 0
        ffmpeg_args += ['-map', '1:a:0']  # first audio stream of input 1
        ffmpeg_args += ['-shortest', temp_output]  # stop at the shorter input

        logging.info("Start merging video and audio...")
        completed = subprocess.run(
            ffmpeg_args, capture_output=True, text=True)

        if completed.returncode != 0:
            failure = f"FFmpeg execute failed: {completed.stderr}"
            logging.error(failure)
            raise RuntimeError(failure)

        shutil.move(temp_output, video_path)
        logging.info(f"Merge completed, saved to {video_path}")

    except Exception as e:
        # best effort: drop any partial output and report, but do not re-raise
        if os.path.exists(temp_output):
            os.remove(temp_output)
        logging.error(f"merge_video_audio failed with error: {e}")
|
|
|
|
|
|
def save_video(tensor,
               save_file=None,
               fps=30,
               suffix='.mp4',
               nrow=8,
               normalize=True,
               value_range=(-1, 1)):
    """Encode a video tensor as an H.264 file.

    The input is split along dimension 2; each slice is tiled into a single
    grid image with ``torchvision.utils.make_grid`` and the grids are written
    as consecutive frames with imageio (codec ``libx264``, quality 8).

    Args:
        tensor: Video data; frames are taken along dimension 2.
            NOTE(review): presumably laid out as (B, C, T, H, W) - confirm
            against callers.
        save_file: Output path. When None, a random name under ``/tmp`` with
            the given ``suffix`` is used.
        fps: Frame rate passed to the imageio writer.
        suffix: Extension for the generated name; only used when
            ``save_file`` is None.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: Values are clamped to this (min, max) pair, which is also
            forwarded to ``make_grid``.

    Returns:
        The path the video was written to, or None when saving failed (the
        error is logged, not raised). Returning the path lets callers locate
        the generated file when ``save_file`` was not supplied.
    """
    # cache file
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    try:
        # preprocess: one grid per frame, then reorder to (T, H, W, C)
        tensor = tensor.clamp(min(value_range), max(value_range))
        grids = [
            torchvision.utils.make_grid(
                u, nrow=nrow, normalize=normalize, value_range=value_range)
            for u in tensor.unbind(2)
        ]
        tensor = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
        tensor = (tensor * 255).type(torch.uint8).cpu()

        # write video; close the writer even if a frame fails to encode
        writer = imageio.get_writer(
            cache_file, fps=fps, codec='libx264', quality=8)
        try:
            for frame in tensor.numpy():
                writer.append_data(frame)
        finally:
            writer.close()
    except Exception as e:
        logging.info(f'save_video failed, error: {e}')
        return None

    return cache_file
|
|
|
|
|
|
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
    """Write an image tensor to ``save_file``.

    Values are clamped to ``value_range`` and the result is handed to
    ``torchvision.utils.save_image``. ``save_file`` is passed through
    unchanged; no extension checking or rewriting is done here.

    Args:
        tensor: Image data accepted by ``torchvision.utils.save_image``.
        save_file: Destination path.
        nrow: Forwarded to ``torchvision.utils.save_image``.
        normalize: Forwarded to ``torchvision.utils.save_image``.
        value_range: Values are clamped to this (min, max) pair, which is also
            forwarded to ``torchvision.utils.save_image``.

    Returns:
        ``save_file`` on success, or None when saving failed (the error is
        logged, not raised).
    """
    try:
        tensor = tensor.clamp(min(value_range), max(value_range))
        torchvision.utils.save_image(
            tensor,
            save_file,
            nrow=nrow,
            normalize=normalize,
            value_range=value_range)
    except Exception as e:
        logging.info(f'save_image failed, error: {e}')
        return None

    return save_file
|
|
|
|
|
|
def str2bool(v):
    """
    Interpret a command-line style string as a boolean.

    Matching is case-insensitive. Values that are already ``bool`` are
    returned unchanged.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|
|
|
|
|
def masks_like(tensor, zero=False, generator=None, p=0.2):
    """Create two lists of all-ones masks shaped like the given tensors.

    Each mask copies the shape, dtype and device of its source tensor. With
    ``zero=True`` the slice ``[:, 0]`` of each mask is modified:

    * without a generator, that slice is zeroed in both masks;
    * with a generator, one uniform sample is drawn per tensor. When the
      sample is below ``p``, the first mask's slice is filled with
      ``exp(x)`` for a single ``x`` drawn from a normal distribution
      (mean -3.5, std 0.5) and the second mask's slice is zeroed.
      Otherwise both masks are left as ones.

    Args:
        tensor (list): Source tensors.
        zero (bool): Whether to modify the ``[:, 0]`` slice at all.
        generator: Optional ``torch.Generator`` driving the random choice.
        p (float): Threshold compared against the uniform sample.

    Returns:
        tuple: ``(out1, out2)``, two lists of masks aligned with ``tensor``.
    """
    assert isinstance(tensor, list)

    first = [torch.ones(t.shape, dtype=t.dtype, device=t.device) for t in tensor]
    second = [torch.ones(t.shape, dtype=t.dtype, device=t.device) for t in tensor]

    if not zero:
        return first, second

    if generator is None:
        # deterministic variant: always blank the leading slice
        for a, b in zip(first, second):
            a[:, 0] = torch.zeros_like(a[:, 0])
            b[:, 0] = torch.zeros_like(b[:, 0])
        return first, second

    for a, b in zip(first, second):
        # one uniform draw per tensor decides whether this pair is altered
        draw = torch.rand(
            1, generator=generator, device=generator.device).item()
        if draw < p:
            log_scale = torch.normal(
                mean=-3.5,
                std=0.5,
                size=(1,),
                device=a.device,
                generator=generator)
            a[:, 0] = log_scale.expand_as(a[:, 0]).exp()
            b[:, 0] = torch.zeros_like(b[:, 0])
        else:
            # values stay as ones; the self-assignment does not change them
            a[:, 0] = a[:, 0]
            b[:, 0] = b[:, 0]

    return first, second
|
|
|
|
|
|
def best_output_size(w, h, dw, dh, expected_area):
    """Pick an output size close to the input aspect ratio within an area budget.

    Two candidates are built from the ideal (fractional) size whose area is
    ``expected_area`` and whose aspect ratio is ``w / h``: one rounds the
    width down to a multiple of ``dw`` first and derives the height from the
    remaining area, the other starts from the height. The candidate whose
    aspect ratio deviates less from ``w / h`` is returned; on a tie the
    height-first candidate wins.

    Args:
        w: Source width.
        h: Source height.
        dw: Required width granularity (output width is a multiple of it).
        dh: Required height granularity (output height is a multiple of it).
        expected_area: Upper bound for ``width * height`` of the result.

    Returns:
        tuple: ``(width, height)`` as integers.
    """

    def _floor_to(value, step):
        # round down to the nearest multiple of step
        return int(value // step * step)

    aspect = w / h

    def _distortion(candidate):
        # symmetric ratio error: 1.0 means the aspect ratio is matched exactly
        return max(aspect / candidate, candidate / aspect)

    # ideal fractional size
    ideal_w = (expected_area * aspect)**0.5
    ideal_h = expected_area / ideal_w

    # candidate A: fix the width first
    width_a = _floor_to(ideal_w, dw)
    height_a = _floor_to(expected_area / width_a, dh)
    assert (width_a % dw == 0 and height_a % dh == 0 and
            width_a * height_a <= expected_area)
    aspect_a = width_a / height_a

    # candidate B: fix the height first
    height_b = _floor_to(ideal_h, dh)
    width_b = _floor_to(expected_area / height_b, dw)
    assert (height_b % dh == 0 and width_b % dw == 0 and
            width_b * height_b <= expected_area)
    aspect_b = width_b / height_b

    if _distortion(aspect_a) < _distortion(aspect_b):
        return width_a, height_a
    return width_b, height_b
|
|
|
|
|
|
def download_cosyvoice_repo(repo_path):
    """Clone the CosyVoice GitHub repository into ``repo_path``.

    The ``main`` branch is cloned with the ``--recursive`` option. GitPython
    is imported lazily so it is only needed when this function is called.

    Args:
        repo_path: Destination directory for the clone.

    Raises:
        ImportError: If GitPython is not installed.
    """
    try:
        import git
    except ImportError:
        raise ImportError('failed to import git, please run pip install GitPython')

    git.Repo.clone_from(
        'https://github.com/FunAudioLLM/CosyVoice.git',
        repo_path,
        multi_options=['--recursive'],
        branch='main',
    )
|
|
|
|
|
|
def download_cosyvoice_model(model_name, model_path):
    """Download the ``iic/<model_name>`` snapshot into ``model_path``.

    Uses ``modelscope.snapshot_download``; the import is deferred to call
    time.

    Args:
        model_name: Model name appended to the ``iic/`` prefix.
        model_path: Local directory passed as ``local_dir``.
    """
    from modelscope import snapshot_download

    model_id = f'iic/{model_name}'
    snapshot_download(model_id, local_dir=model_path)
|