# sdnext/modules/framepack/framepack_worker.py
# source: https://github.com/vladmandic/sdnext.git (commit 0faf61f48a "video tab add params.txt", 2025-10-22)
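# FramePack generation worker: encodes prompts, start/end images and vision features,
# samples the video transformer one latent section at a time, decodes with the VAE and
# saves the result, reporting progress through the module-level AsyncStream queues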
import time
import torch
import rich.progress as rp
from modules import shared, errors, devices, sd_models, timer, memstats
from modules.framepack import framepack_vae # pylint: disable=wrong-import-order
from modules.framepack import framepack_hijack # pylint: disable=wrong-import-order
from modules.video_models.video_save import save_video # pylint: disable=wrong-import-order

stream = None # AsyncStream
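

# map the requested clip length to a list of latent 'sections': the forward-only (F1) variant
# generates sections in order, the default variant generates them in reverse and adds extra
# padding entries for longer clips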
def get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant):
    try:
        real_fps = mp4_fps / (mp4_interpolate + 1)
        is_f1 = variant == 'forward-only'
        if is_f1:
            total_latent_sections = (total_second_length * real_fps) / (latent_window_size * 4)
            total_latent_sections = int(max(round(total_latent_sections), 1))
            latent_paddings = list(range(total_latent_sections))
        else:
            total_latent_sections = int(max((total_second_length * real_fps) / (latent_window_size * 4), 1))
            latent_paddings = list(reversed(range(total_latent_sections)))
            if total_latent_sections > 4: # extra padding for better quality
                # latent_paddings = list(reversed(range(total_latent_sections)))
                latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
    except Exception:
        latent_paddings = [0]
    return latent_paddings
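

# main generation worker: progress and results are pushed to stream.output_queue,
# interrupt requests are read from stream.input_queue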
def worker(
        input_image, end_image,
        start_weight, end_weight, vision_weight,
        prompts, n_prompt, system_prompt, optimized_prompt, unmodified_prompt,
        seed,
        total_second_length,
        latent_window_size,
        steps,
        cfg_scale, cfg_distilled, cfg_rescale,
        shift,
        use_teacache, use_cfgzero, use_preview,
        mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
        vae_type,
        variant,
        metadata:dict={},
):
    timer.process.reset()
    memstats.reset_stats()
    if stream is None or shared.state.interrupted or shared.state.skipped:
        shared.log.error('FramePack: stream is None')
        if stream is not None: # avoid attribute access when the stream was never created
            stream.output_queue.push(('end', None))
        return
    from modules.framepack.pipeline import hunyuan
    from modules.framepack.pipeline import utils
    from modules.framepack.pipeline import k_diffusion_hunyuan

    is_f1 = variant == 'forward-only'
    total_generated_frames = 0
    total_generated_latent_frames = 0
    latent_paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant)
    num_frames = latent_window_size * 4 - 3 # number of frames to generate in each section
    metadata['title'] = 'sdnext framepack'
    metadata['description'] = f'variant:{variant} seed:{seed} steps:{steps} scale:{cfg_scale} distilled:{cfg_distilled} rescale:{cfg_rescale} shift:{shift} start:{start_weight} end:{end_weight} vision:{vision_weight}'
    videojob = shared.state.begin('Video')
    shared.state.job_count = 1
    text_encoder = shared.sd_model.text_encoder
    text_encoder_2 = shared.sd_model.text_encoder_2
    tokenizer = shared.sd_model.tokenizer
    tokenizer_2 = shared.sd_model.tokenizer_2
    feature_extractor = shared.sd_model.feature_extractor
    image_encoder = shared.sd_model.image_processor
    transformer = shared.sd_model.transformer
    sd_models.apply_balanced_offload(shared.sd_model)
    pbar = rp.Progress(rp.TextColumn('[cyan]Video'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
    task = pbar.add_task('starting', total=steps * len(latent_paddings))
    t_last = time.time()
    if not is_f1:
        prompts = list(reversed(prompts))
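
    # encode positive and negative prompts with the llama text encoder and clip-l pooler;
    # negative conditioning is zeroed when cfg_scale <= 1 or no negative prompt is given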
    def text_encode(prompt, i:int=None):
        jobid = shared.state.begin('TE Encode')
        pbar.update(task, description=f'text encode section={i}')
        t0 = time.time()
        torch.manual_seed(seed)
        # shared.log.debug(f'FramePack: section={i} prompt="{prompt}"')
        shared.state.textinfo = 'Text encode'
        stream.output_queue.push(('progress', (None, 'Text encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
        llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        metadata['comment'] = prompt
        if cfg_scale > 1 and n_prompt is not None and len(n_prompt) > 0:
            llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        else:
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('prompt', time.time()-t0)
        shared.state.end(jobid)
        return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n
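
    # VAE-encode the start image (and optional end image) into latents;
    # start_weight < 1 blends noise into the start latent to loosen adherence to the input image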
    def latents_encode(input_image, end_image):
        jobid = shared.state.begin('VAE Encode')
        pbar.update(task, description='image encode')
        # shared.log.debug(f'FramePack: image encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
        t0 = time.time()
        torch.manual_seed(seed)
        stream.output_queue.push(('progress', (None, 'VAE encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        if input_image is not None:
            input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
            input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
            start_latent = framepack_vae.vae_encode(input_image_pt)
            if start_weight < 1:
                noise = torch.randn_like(start_latent)
                start_latent = start_latent * start_weight + noise * (1 - start_weight)
        if end_image is not None:
            end_image_pt = torch.from_numpy(end_image).float() / 127.5 - 1
            end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
            end_latent = framepack_vae.vae_encode(end_image_pt)
        else:
            end_latent = None
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('encode', time.time()-t0)
        shared.state.end(jobid)
        return start_latent, end_latent
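
    # encode the input image (and optional end image) with the siglip vision encoder;
    # the resulting hidden state is used as image conditioning for the transformer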
    def vision_encode(input_image, end_image):
        pbar.update(task, description='vision encode')
        # shared.log.debug(f'FramePack: vision encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
        t0 = time.time()
        shared.state.textinfo = 'Vision encode'
        stream.output_queue.push(('progress', (None, 'Vision encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        # siglip doesn't work with offload
        sd_models.move_model(feature_extractor, devices.device, force=True)
        sd_models.move_model(image_encoder, devices.device, force=True)
        preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
        image_encoder_output = image_encoder(**preprocessed)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
        if end_image is not None:
            preprocessed = feature_extractor.preprocess(images=end_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
            end_image_encoder_output = image_encoder(**preprocessed)
            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
            image_encoder_last_hidden_state = ((image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight)) / (start_weight + end_weight) # weighted average of start/end embeddings
        image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('vision', time.time()-t0)
        return image_encoder_last_hidden_state
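
    # per-step sampler callback: zeroes the first-step prediction when cfg-zero is enabled,
    # handles interrupt/pause requests, updates progress and optionally decodes a latent preview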
    def step_callback(d):
        if use_cfgzero and is_first_section and d['i'] == 0:
            d['denoised'] = d['denoised'] * 0
        t_current = time.time()
        if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
            stream.output_queue.push(('progress', (None, 'Interrupted...')))
            stream.output_queue.push(('end', None))
            raise AssertionError('Interrupted...')
        if shared.state.paused:
            shared.log.debug('Sampling paused')
            while shared.state.paused:
                if shared.state.interrupted or shared.state.skipped:
                    raise AssertionError('Interrupted...')
                time.sleep(0.1)
        nonlocal total_generated_frames, t_last
        t_preview = time.time()
        current_step = d['i'] + 1
        shared.state.textinfo = ''
        shared.state.sampling_step = ((lattent_padding_loop-1) * steps) + current_step
        shared.state.sampling_steps = steps * len(latent_paddings)
        progress = shared.state.sampling_step / shared.state.sampling_steps
        total_generated_frames = int(max(0, total_generated_latent_frames * 4 - 3))
        pbar.update(task, advance=1, description=f'its={1/(t_current-t_last):.2f} sample={d["i"]+1}/{steps} section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)}')
        desc = f'Step {shared.state.sampling_step}/{shared.state.sampling_steps} | Current {current_step}/{steps} | Section {lattent_padding_loop}/{len(latent_paddings)} | Progress {progress:.2%}'
        if use_preview:
            preview = framepack_vae.vae_decode(d['denoised'], 'Preview')
            stream.output_queue.push(('progress', (preview, desc)))
        else:
            stream.output_queue.push(('progress', (None, desc)))
        timer.process.add('preview', time.time() - t_preview)
        t_last = t_current
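
    # main loop: encode the inputs once, then sample the video one latent section at a time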
    try:
        with devices.inference_context(), pbar:
            t0 = time.time()
            height, width, _C = input_image.shape
            start_latent, end_latent = latents_encode(input_image, end_image)
            image_encoder_last_hidden_state = vision_encode(input_image, end_image)
            # Sample loop
            stream.output_queue.push(('progress', (None, 'Start sampling...')))
            generator = torch.Generator("cpu").manual_seed(seed)
            if is_f1:
                history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
            else:
                history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=devices.dtype).cpu()
            history_pixels = None
            lattent_padding_loop = 0
            last_prompt = None
            for latent_padding in latent_paddings:
                current_prompt = prompts[lattent_padding_loop]
                if current_prompt != last_prompt:
                    llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n = text_encode(current_prompt, i=lattent_padding_loop+1)
                    last_prompt = current_prompt
                sammplejob = shared.state.begin('Sample')
                lattent_padding_loop += 1
                # shared.log.trace(f'FramePack: op=sample section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)} window={latent_window_size} size={num_frames}')
                if is_f1:
                    is_first_section, is_last_section = False, False
                else:
                    is_first_section, is_last_section = latent_padding == latent_paddings[0], latent_padding == 0
                if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
                    stream.output_queue.push(('end', None))
                    return
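                # build the index layout for this section: the latent window to generate plus
                # 'clean' context latents (1x/2x/4x) taken from history_latents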
                if is_f1:
                    indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
                    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
                    clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
                    clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
                else:
                    latent_padding_size = latent_padding * latent_window_size
                    indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                    clean_latent_indices_pre, _blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
                    clean_latents_pre = start_latent.to(history_latents)
                    clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
                if end_image is not None and is_first_section:
                    # blend the end-image latent into the history context using a weighted average
                    clean_latents_post = ((clean_latents_post * start_weight / len(latent_paddings)) + (end_weight * end_latent.to(history_latents))) / (start_weight / len(latent_paddings) + end_weight) # pylint: disable=possibly-used-before-assignment
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
                sd_models.apply_balanced_offload(shared.sd_model)
                transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps, rel_l1_thresh=shared.opts.teacache_thresh)
                t_sample = time.time()
                generated_latents = k_diffusion_hunyuan.sample_hunyuan(
                    transformer=transformer,
                    sampler='unipc',
                    width=width,
                    height=height,
                    frames=num_frames,
                    num_inference_steps=steps,
                    real_guidance_scale=cfg_scale,
                    distilled_guidance_scale=cfg_distilled,
                    guidance_rescale=cfg_rescale,
                    shift=shift if shift > 0 else None,
                    generator=generator,
                    prompt_embeds=llama_vec, # pylint: disable=possibly-used-before-assignment
                    prompt_embeds_mask=llama_attention_mask, # pylint: disable=possibly-used-before-assignment
                    prompt_poolers=clip_l_pooler, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds=llama_vec_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds_mask=llama_attention_mask_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_poolers=clip_l_pooler_n, # pylint: disable=possibly-used-before-assignment
                    image_embeddings=image_encoder_last_hidden_state,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    device=devices.device,
                    dtype=devices.dtype,
                    callback=step_callback,
                )
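                # on the last section prepend the start-image latent, then grow the latent history:
                # F1 appends new sections at the end, the default variant prepends them at the front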
                if is_last_section:
                    generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
                total_generated_latent_frames += int(generated_latents.shape[2])
                if is_f1:
                    history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
                    real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
                else:
                    history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
                    real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
                sd_models.apply_balanced_offload(shared.sd_model)
                timer.process.add('sample', time.time()-t_sample)
                shared.state.end(sammplejob)
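                # decode the generated latents to pixels; later sections are soft-blended into the
                # already-decoded frames across the overlapping window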
                t_vae = time.time()
                if history_pixels is None:
                    history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
                else:
                    overlapped_frames = latent_window_size * 4 - 3
                    if is_f1:
                        section_latent_frames = latent_window_size * 2
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, -section_latent_frames:], vae_type=vae_type).cpu()
                        history_pixels = utils.soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
                    else:
                        section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
                        history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
                sd_models.apply_balanced_offload(shared.sd_model)
                timer.process.add('vae', time.time()-t_vae)
                if is_last_section:
                    break
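
            # all sections done: write the assembled frames to the output video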
            total_generated_frames, _video_filename = save_video(
                None,
                history_pixels,
                mp4_fps,
                mp4_codec,
                mp4_opt,
                mp4_ext,
                mp4_sf,
                mp4_video,
                mp4_frames,
                mp4_interpolate,
                pbar=pbar,
                stream=stream,
                metadata=metadata,
            )
    except AssertionError:
        shared.log.info('FramePack: interrupted')
        if shared.opts.keep_incomplete:
            save_video(None, history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate=0, stream=stream, metadata=metadata)
    except Exception as e:
        shared.log.error(f'FramePack: {e}')
        errors.display(e, 'FramePack')
    sd_models.apply_balanced_offload(shared.sd_model)
    stream.output_queue.push(('end', None))
    t1 = time.time()
    shared.log.info(f'Processed: frames={total_generated_frames} fps={total_generated_frames/(t1-t0):.2f} its={(shared.state.sampling_step)/(t1-t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
    shared.state.end(videojob)