mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

handle controlnet better

Author: kijai
Date: 2025-12-07 00:27:23 +02:00
parent c5742552a9
commit 3e4e4db35d
3 changed files with 22 additions and 12 deletions


@@ -904,7 +904,7 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None,
         if cnt % 100 == 0:
             pbar.update(100)
-    [print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
+    #[print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
     pbar.update_absolute(0)
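The only change in this hunk is commenting out a per-parameter debug print that ran after every weight load. If one wanted to keep that kind of dump available without leaving dead code around, a hypothetical flag-gated helper (not part of this repo) could look like:

import torch

def dump_param_placement(module: torch.nn.Module, enabled: bool = False) -> None:
    # Hypothetical debug helper: report where each parameter ended up
    # (device and dtype) once loading has finished.
    if not enabled:
        return
    for name, param in module.named_parameters():
        print(name, param.device, param.dtype)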


@@ -1190,6 +1190,12 @@ class WanVideoSampler:
log.info(f" {k}: {v.shape if isinstance(v, torch.Tensor) else v}") log.info(f" {k}: {v.shape if isinstance(v, torch.Tensor) else v}")
one_to_all_data = one_to_all_embeds.copy() one_to_all_data = one_to_all_embeds.copy()
one_to_all_data = dict_to_device(one_to_all_data, device, dtype) one_to_all_data = dict_to_device(one_to_all_data, device, dtype)
if one_to_all_embeds.get("pose_images") is not None:
pose_images = transformer.input_hint_block(one_to_all_data.pop("pose_images"))
if one_to_all_embeds.get("ref_latent_pos") is not None:
pose_prefix_image = transformer.input_hint_block(one_to_all_data.pop("pose_prefix_image"))
pose_images = torch.cat([pose_prefix_image, pose_images],dim=2)
one_to_all_data["controlnet_tokens"] = pose_images.flatten(2).transpose(1, 2)
#region model pred #region model pred
def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None, def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
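This new sampler-side block runs input_hint_block once, before sampling starts, and stores the result as "controlnet_tokens", so the per-step model forward no longer has to re-encode the pose frames. A minimal sketch of the token shaping it performs, assuming the hint encoder returns a 5-D feature map of shape (B, C, T, H, W) (the encoder below is a stand-in, not the repo's input_hint_block):

import torch

def features_to_tokens(feat: torch.Tensor) -> torch.Tensor:
    # (B, C, T, H, W) -> (B, T*H*W, C): flatten every temporal/spatial
    # position into the sequence axis, then put channels last so each
    # position becomes one token, matching flatten(2).transpose(1, 2) above.
    return feat.flatten(2).transpose(1, 2)

feat = torch.randn(1, 16, 4, 8, 8)        # stand-in pose hint features
prefix = torch.randn(1, 16, 1, 8, 8)      # stand-in prefix-frame features
feat = torch.cat([prefix, feat], dim=2)   # prepend the prefix along the frame axis
tokens = features_to_tokens(feat)         # -> (1, 5*8*8, 16) = (1, 320, 16)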


@@ -2308,18 +2308,11 @@ class WanModel(torch.nn.Module):
                 suffix_frames += 1
             onetoall_ref_block_samples, onetoall_freqs = self.refextractor(ref_cond_latent, timestep=t)
-            controlnet_video = one_to_all_input.get("pose_images")
-            controlnet_tokens = None
             # pose controlnet
-            if controlnet_video is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
+            controlnet_tokens = one_to_all_input.get("controlnet_tokens", None)
+            if controlnet_tokens is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
                 onetoall_control_enabled = True
                 onetoall_control_strength = one_to_all_input.get("controlnet_strength", 1.0)
-                image_pose = one_to_all_input.get("pose_prefix_image")
-                controlnet_video = self.input_hint_block(controlnet_video)
-                controlnet_prefix = self.input_hint_block(image_pose)
-                if ref_cond_latent is not None:
-                    controlnet_video = torch.cat([controlnet_prefix, controlnet_video],dim=2)
-                controlnet_tokens = controlnet_video.flatten(2).transpose(1, 2)

         #uni3c controlnet
         if uni3c_data is not None:
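With the tokens precomputed in the sampler, the model now only looks up "controlnet_tokens" and checks the step window before enabling the control branch. A minimal sketch of that start/end-percent gating pattern, assuming the step percentage is step_idx / (total_steps - 1) (names are illustrative, not the repo's API):

def controlnet_active(step_idx: int, total_steps: int,
                      start_percent: float, end_percent: float) -> bool:
    # Enable the control branch only inside the configured fraction of the
    # denoising schedule, mirroring the start/end percent check above.
    current = step_idx / max(total_steps - 1, 1)
    return start_percent <= current <= end_percent

# e.g. apply the pose ControlNet only during the first 60% of 30 steps
active_steps = [i for i in range(30) if controlnet_active(i, 30, 0.0, 0.6)]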
@@ -3003,20 +2996,31 @@ class WanModel(torch.nn.Module):
             #controlnet
             if (controlnet is not None) and (b % controlnet["controlnet_stride"] == 0) and (b // controlnet["controlnet_stride"] < len(controlnet["controlnet_states"])):
                 x[:, :self.original_seq_len] += controlnet["controlnet_states"][b // controlnet["controlnet_stride"]].to(x) * controlnet["controlnet_weight"]

-            # One-to-All controlnet
+            # One-to-All-Animation controlnet
             if onetoall_control_enabled:
                 if prev_x is not None and (b - 1) < len(self.controlnet.blocks):
+                    tqdm.write(f"Applying One-to-All ControlNet at block {b}")
                     if b == 1:
                         ctrl_in = prev_x + controlnet_tokens
                     elif prev_control is not None:
                         ctrl_in = prev_control
+                    self.controlnet.blocks[b - 1].to(self.main_device)
                     control_out = self.controlnet.blocks[b - 1](ctrl_in, e0, seq_lens, freqs, split_rope=False)
+                    self.controlnet.blocks[b - 1].to(self.offload_device, non_blocking=self.use_non_blocking)
                     prev_control = control_out
                     control_out_proj = self.controlnet_zero[b - 1](control_out)
                     x = x + control_out_proj * onetoall_control_strength
-                prev_x = x
+                if b < len(self.controlnet.blocks): # Store prev_x only while controlnet is active
+                    prev_x = x
+                elif b == len(self.controlnet.blocks): # Controlnet done, free memory
+                    prev_x = None
+                    prev_control = None
+                    if controlnet_tokens is not None:
+                        del controlnet_tokens
+                        controlnet_tokens = None
+                    mm.soft_empty_cache()

         if lynx_ref_feature_extractor:
             return lynx_ref_buffer
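Two other things happen in this hunk: each ControlNet block is moved to the compute device only for its own forward pass and then offloaded again, and the prev_x / prev_control / controlnet_tokens buffers are dropped once the last ControlNet block has run. A minimal sketch of the per-call offload pattern, under assumed device handling (the diff itself uses self.main_device, self.offload_device and self.use_non_blocking; everything else below is illustrative):

import torch

def run_block_offloaded(block: torch.nn.Module, x: torch.Tensor,
                        main_device: torch.device, offload_device: torch.device,
                        non_blocking: bool = True) -> torch.Tensor:
    # Move the block onto the compute device just for this call...
    block.to(main_device)
    out = block(x.to(main_device))
    # ...then push its weights back to the offload device so only one
    # ControlNet block occupies VRAM at a time.
    block.to(offload_device, non_blocking=non_blocking)
    return out

main = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
block = torch.nn.Linear(64, 64)
y = run_block_offloaded(block, torch.randn(2, 64), main, torch.device("cpu"))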