You've already forked ComfyUI-WanVideoWrapper
mirror of
https://github.com/kijai/ComfyUI-WanVideoWrapper.git
synced 2026-01-26 23:41:35 +03:00
handle controlnet better
This commit is contained in:
@@ -904,7 +904,7 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None,
|
|||||||
if cnt % 100 == 0:
|
if cnt % 100 == 0:
|
||||||
pbar.update(100)
|
pbar.update(100)
|
||||||
|
|
||||||
[print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
|
#[print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
|
||||||
|
|
||||||
pbar.update_absolute(0)
|
pbar.update_absolute(0)
|
||||||
|
|
||||||
|
|||||||
@@ -1190,6 +1190,12 @@ class WanVideoSampler:
|
|||||||
log.info(f" {k}: {v.shape if isinstance(v, torch.Tensor) else v}")
|
log.info(f" {k}: {v.shape if isinstance(v, torch.Tensor) else v}")
|
||||||
one_to_all_data = one_to_all_embeds.copy()
|
one_to_all_data = one_to_all_embeds.copy()
|
||||||
one_to_all_data = dict_to_device(one_to_all_data, device, dtype)
|
one_to_all_data = dict_to_device(one_to_all_data, device, dtype)
|
||||||
|
if one_to_all_embeds.get("pose_images") is not None:
|
||||||
|
pose_images = transformer.input_hint_block(one_to_all_data.pop("pose_images"))
|
||||||
|
if one_to_all_embeds.get("ref_latent_pos") is not None:
|
||||||
|
pose_prefix_image = transformer.input_hint_block(one_to_all_data.pop("pose_prefix_image"))
|
||||||
|
pose_images = torch.cat([pose_prefix_image, pose_images],dim=2)
|
||||||
|
one_to_all_data["controlnet_tokens"] = pose_images.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
#region model pred
|
#region model pred
|
||||||
def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
|
def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
|
||||||
|
|||||||
@@ -2308,18 +2308,11 @@ class WanModel(torch.nn.Module):
|
|||||||
suffix_frames += 1
|
suffix_frames += 1
|
||||||
onetoall_ref_block_samples, onetoall_freqs = self.refextractor(ref_cond_latent, timestep=t)
|
onetoall_ref_block_samples, onetoall_freqs = self.refextractor(ref_cond_latent, timestep=t)
|
||||||
|
|
||||||
controlnet_video = one_to_all_input.get("pose_images")
|
|
||||||
controlnet_tokens = None
|
|
||||||
# pose controlnet
|
# pose controlnet
|
||||||
if controlnet_video is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
|
controlnet_tokens = one_to_all_input.get("controlnet_tokens", None)
|
||||||
|
if controlnet_tokens is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
|
||||||
onetoall_control_enabled = True
|
onetoall_control_enabled = True
|
||||||
onetoall_control_strength = one_to_all_input.get("controlnet_strength", 1.0)
|
onetoall_control_strength = one_to_all_input.get("controlnet_strength", 1.0)
|
||||||
image_pose = one_to_all_input.get("pose_prefix_image")
|
|
||||||
controlnet_video = self.input_hint_block(controlnet_video)
|
|
||||||
controlnet_prefix = self.input_hint_block(image_pose)
|
|
||||||
if ref_cond_latent is not None:
|
|
||||||
controlnet_video = torch.cat([controlnet_prefix, controlnet_video],dim=2)
|
|
||||||
controlnet_tokens = controlnet_video.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
#uni3c controlnet
|
#uni3c controlnet
|
||||||
if uni3c_data is not None:
|
if uni3c_data is not None:
|
||||||
@@ -3003,20 +2996,31 @@ class WanModel(torch.nn.Module):
|
|||||||
#controlnet
|
#controlnet
|
||||||
if (controlnet is not None) and (b % controlnet["controlnet_stride"] == 0) and (b // controlnet["controlnet_stride"] < len(controlnet["controlnet_states"])):
|
if (controlnet is not None) and (b % controlnet["controlnet_stride"] == 0) and (b // controlnet["controlnet_stride"] < len(controlnet["controlnet_states"])):
|
||||||
x[:, :self.original_seq_len] += controlnet["controlnet_states"][b // controlnet["controlnet_stride"]].to(x) * controlnet["controlnet_weight"]
|
x[:, :self.original_seq_len] += controlnet["controlnet_states"][b // controlnet["controlnet_stride"]].to(x) * controlnet["controlnet_weight"]
|
||||||
# One-to-All controlnet
|
# One-to-All-Animation controlnet
|
||||||
if onetoall_control_enabled:
|
if onetoall_control_enabled:
|
||||||
if prev_x is not None and (b - 1) < len(self.controlnet.blocks):
|
if prev_x is not None and (b - 1) < len(self.controlnet.blocks):
|
||||||
|
tqdm.write(f"Applying One-to-All ControlNet at block {b}")
|
||||||
if b == 1:
|
if b == 1:
|
||||||
ctrl_in = prev_x + controlnet_tokens
|
ctrl_in = prev_x + controlnet_tokens
|
||||||
elif prev_control is not None:
|
elif prev_control is not None:
|
||||||
ctrl_in = prev_control
|
ctrl_in = prev_control
|
||||||
|
|
||||||
|
self.controlnet.blocks[b - 1].to(self.main_device)
|
||||||
control_out = self.controlnet.blocks[b - 1](ctrl_in, e0, seq_lens, freqs, split_rope=False)
|
control_out = self.controlnet.blocks[b - 1](ctrl_in, e0, seq_lens, freqs, split_rope=False)
|
||||||
|
self.controlnet.blocks[b - 1].to(self.offload_device, non_blocking=self.use_non_blocking)
|
||||||
prev_control = control_out
|
prev_control = control_out
|
||||||
|
|
||||||
control_out_proj = self.controlnet_zero[b - 1](control_out)
|
control_out_proj = self.controlnet_zero[b - 1](control_out)
|
||||||
x = x + control_out_proj * onetoall_control_strength
|
x = x + control_out_proj * onetoall_control_strength
|
||||||
prev_x = x
|
if b < len(self.controlnet.blocks): # Store prev_x only while controlnet is active
|
||||||
|
prev_x = x
|
||||||
|
elif b == len(self.controlnet.blocks): # Controlnet done, free memory
|
||||||
|
prev_x = None
|
||||||
|
prev_control = None
|
||||||
|
if controlnet_tokens is not None:
|
||||||
|
del controlnet_tokens
|
||||||
|
controlnet_tokens = None
|
||||||
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
if lynx_ref_feature_extractor:
|
if lynx_ref_feature_extractor:
|
||||||
return lynx_ref_buffer
|
return lynx_ref_buffer
|
||||||
|
|||||||
Reference in New Issue
Block a user