# sdnext/modules/ltx/ltx_process.py
import os
import time
import torch
from PIL import Image
from modules import shared, errors, timer, memstats, progress, processing, sd_models, sd_samplers, extra_networks, call_queue
from modules.video_models.video_vae import set_vae_params
from modules.video_models.video_save import save_video
from modules.video_models.video_utils import check_av
from modules.processing_callbacks import diffusers_callback
from modules.ltx.ltx_util import get_bucket, get_frames, load_model, load_upsample, get_conditions, get_generator, get_prompts, vae_decode
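# optional trace logging: active only when the SD_VIDEO_DEBUG environment variable is set, otherwise a no-op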
debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
# engine, model = 'LTX Video', 'LTXVideo 0.9.7 13B'
upsample_repo_id = "a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers"
upsample_pipe = None
def run_ltx(task_id,
_ui_state,
model:str,
prompt:str,
negative:str,
styles:list[str],
width:int,
height:int,
frames:int,
steps:int,
sampler_index:int,
seed:int,
upsample_enable:bool,
upsample_ratio:float,
refine_enable:bool,
refine_strength:float,
condition_strength:float,
condition_image,
condition_last,
condition_files,
condition_video,
condition_video_frames:int,
condition_video_skip:int,
decode_timestep:float,
image_cond_noise_scale:float,
mp4_fps:int,
mp4_interpolate:int,
mp4_codec:str,
mp4_ext:str,
mp4_opt:str,
mp4_video:bool,
mp4_frames:bool,
mp4_sf:bool,
audio_enable:bool,
_overrides,
):
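    """Generate video using an LTX/LTX2 pipeline.

    Runs the base denoising pass, optional latent upsample and refine passes,
    VAE decode and final save; yields (video_file, status) tuples for the UI.
    """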
def abort(e, ok:bool=False, p=None):
if ok:
shared.log.info(e)
else:
shared.log.error(f'Video: cls={shared.sd_model.__class__.__name__} op=base {e}')
errors.display(e, 'LTX')
if p is not None:
extra_networks.deactivate(p)
shared.state.end()
progress.finish_task(task_id)
yield None, f'LTX Error: {str(e)}'
if model is None or len(model) == 0:
yield from abort('Video: no model selected', ok=True)
return
# from diffusers import LTXConditionPipeline # pylint: disable=unused-import
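    # check video (av) support is available, then queue the task; generation is serialized behind the shared call-queue lock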
check_av()
progress.add_task_to_queue(task_id)
with call_queue.get_lock():
progress.start_task(task_id)
memstats.reset_stats()
timer.process.reset()
yield None, 'LTX: Loading...'
engine = 'LTX Video'
load_model(engine, model)
debug(f'Video: cls={shared.sd_model.__class__.__name__} op=init model="{model}"')
if not shared.sd_model.__class__.__name__.startswith("LTX"):
            yield from abort(f'Video: cls={shared.sd_model.__class__.__name__} selected model is not an LTX model', ok=True)
return
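        # lightweight processing object used to carry generation parameters, extra-networks state and save metadata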
videojob = shared.state.begin('Video', task_id=task_id)
shared.state.job_count = 1
p = processing.StableDiffusionProcessingVideo(
video_engine=engine,
video_model=model,
prompt=prompt,
negative_prompt=negative,
styles=styles,
width=width,
height=height,
frames=frames,
steps=steps,
sampler_index=sampler_index,
seed=seed,
)
p.ops.append('video')
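        # gather conditioning inputs: initial image, final image, additional image files and an optional source video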
condition_images = []
if condition_image is not None:
condition_images.append(condition_image)
if condition_last is not None:
condition_images.append(condition_last)
conditions = get_conditions(
width,
height,
condition_strength,
condition_images,
condition_files,
condition_video,
condition_video_frames,
condition_video_skip,
)
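        # resolve style templates and extract any extra-network references from the prompts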
prompt, negative, networks = get_prompts(prompt, negative, styles)
sampler_name = processing.get_sampler_name(sampler_index)
sd_samplers.create_sampler(sampler_name, shared.sd_model)
shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} op=init styles={styles} networks={networks} sampler={shared.sd_model.scheduler.__class__.__name__}')
extra_networks.activate(p, networks)
framewise = 'LTX2' not in shared.sd_model.__class__.__name__
set_vae_params(p, framewise=framewise)
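        # balanced offload runs before and after each pass; the surrounding timestamps separate offload time from compute time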
t0 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t1 = time.time()
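        # LTX2 pipelines decode to numpy frames internally, while classic LTX returns latents for the later upsample/refine/decode stages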
if 'LTX2' in shared.sd_model.__class__.__name__:
output_type = 'np'
else:
output_type = 'latent'
base_args = {
"prompt": prompt,
"negative_prompt": negative,
"width": get_bucket(width),
"height": get_bucket(height),
"num_frames": get_frames(frames),
"num_inference_steps": steps,
"generator": get_generator(seed),
"callback_on_step_end": diffusers_callback,
"output_type": output_type,
}
if 'LTX2' in shared.sd_model.__class__.__name__:
base_args["frame_rate"] = float(mp4_fps)
if 'Condition' in shared.sd_model.__class__.__name__:
base_args["image_cond_noise_scale"] = image_cond_noise_scale
shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} op=base {base_args}')
if len(conditions) > 0:
base_args["conditions"] = conditions
        debug(f'LTX args: {base_args}')
yield None, 'LTX: Generate in progress...'
samplejob = shared.state.begin('Sample')
try:
result = shared.sd_model(**base_args)
latents = result.frames[0]
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
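        # keep the generated audio track if the pipeline produced one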
if audio_enable and hasattr(result, 'audio') and result.audio is not None:
audio = result.audio[0].float().cpu()
else:
audio = None
        try:
            debug(f'LTX result frames={latents.shape if latents is not None else None} audio={audio.shape if audio is not None else None}')
        except Exception:
            pass
t2 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t3 = time.time()
timer.process.add('offload', t1 - t0)
timer.process.add('base', t2 - t1)
timer.process.add('offload', t3 - t2)
shared.state.end(samplejob)
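        # optional latent spatial upsampling using the dedicated upsampler pipeline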
if upsample_enable:
t4 = time.time()
upsamplejob = shared.state.begin('Upsample')
global upsample_pipe # pylint: disable=global-statement
upsample_pipe = load_upsample(upsample_pipe, upsample_repo_id)
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
upscale_args = {
"width": get_bucket(upsample_ratio * width),
"height": get_bucket(upsample_ratio * height),
"generator": get_generator(seed),
"output_type": output_type,
}
if latents.ndim == 4:
latents = latents.unsqueeze(0) # add batch dimension
shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} op=upsample latents={latents.shape} {upscale_args}')
yield None, 'LTX: Upsample in progress...'
try:
upsampled_latents = upsample_pipe(latents=latents, **upscale_args).frames[0]
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
latents = upsampled_latents
t5 = time.time()
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
t6 = time.time()
timer.process.add('upsample', t5 - t4)
timer.process.add('offload', t6 - t5)
shared.state.end(upsamplejob)
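        # optional refine pass: re-run the pipeline over the existing latents at partial denoise strength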
if refine_enable:
t7 = time.time()
refinejob = shared.state.begin('Refine')
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
refine_args = {
"prompt": prompt,
"negative_prompt": negative,
"width": get_bucket(upsample_ratio * width),
"height": get_bucket(upsample_ratio * height),
"num_frames": get_frames(frames),
"denoise_strength": refine_strength,
"num_inference_steps": steps,
"image_cond_noise_scale": image_cond_noise_scale,
"generator": get_generator(seed),
"callback_on_step_end": diffusers_callback,
"output_type": output_type,
}
if latents.ndim == 4:
latents = latents.unsqueeze(0) # add batch dimension
shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} op=refine latents={latents.shape} {refine_args}')
if len(conditions) > 0:
refine_args["conditions"] = conditions
yield None, 'LTX: Refine in progress...'
try:
refined_latents = shared.sd_model(latents=latents, **refine_args).frames[0]
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
latents = refined_latents
t8 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t9 = time.time()
timer.process.add('refine', t8 - t7)
timer.process.add('offload', t9 - t8)
shared.state.end(refinejob)
extra_networks.deactivate(p)
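        # decode latents into pixel frames; skipped when the pipeline already returned decoded frames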
yield None, 'LTX: VAE decode in progress...'
try:
if torch.is_tensor(latents):
frames = vae_decode(latents, decode_timestep, seed)
else:
frames = latents
except TypeError:
frames = latents # likely because the latents are already decoded
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
t10 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t11 = time.time()
timer.process.add('offload', t11 - t10)
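        # audio sample rate comes from the model vocoder when available, with a 24kHz fallback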
try:
aac_sample_rate = shared.sd_model.vocoder.config.output_sampling_rate
except Exception:
aac_sample_rate = 24000
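        # mux frames (and audio when present) into the output video and optionally save individual frames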
num_frames, video_file = save_video(
p=p,
pixels=frames,
audio=audio,
mp4_fps=mp4_fps,
mp4_codec=mp4_codec,
mp4_opt=mp4_opt,
mp4_ext=mp4_ext,
mp4_sf=mp4_sf,
mp4_video=mp4_video,
mp4_frames=mp4_frames,
mp4_interpolate=mp4_interpolate,
aac_sample_rate=aac_sample_rate,
metadata={},
)
t_end = time.time()
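        # derive output resolution from whichever frame container was produced: PIL images or 5d/4d/other tensors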
if isinstance(frames, list) and isinstance(frames[0], Image.Image):
w, h = frames[0].size
elif frames.ndim == 5:
_n, _c, _t, h, w = frames.shape
elif frames.ndim == 4:
_n, h, w, _c = frames.shape
else:
h, w = frames.shape[-2], frames.shape[-1]
resolution = f'{w}x{h}' if num_frames > 0 else None
summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
memory = shared.mem_mon.summary()
fps = f'{num_frames/(t_end-t0):.2f}'
its = f'{(steps)/(t_end-t0):.2f}'
shared.state.end(videojob)
progress.finish_task(task_id)
shared.log.info(f'Processed: fn="{video_file}" frames={num_frames} fps={fps} its={its} resolution={resolution} time={t_end-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
        yield video_file, f'LTX: Generation completed | File {video_file} | Frames {num_frames} | Resolution {resolution} | f/s {fps} | it/s {its} ' + f"<div class='performance'><p>{summary} {memory}</p></div>"