diff --git a/nodes_model_loading.py b/nodes_model_loading.py
index fb109bd..54306f6 100644
--- a/nodes_model_loading.py
+++ b/nodes_model_loading.py
@@ -904,7 +904,7 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None,
         if cnt % 100 == 0:
             pbar.update(100)
 
-    [print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
+    #[print(name, param.device, param.dtype) for name, param in transformer.named_parameters()]
 
     pbar.update_absolute(0)
 
diff --git a/nodes_sampler.py b/nodes_sampler.py
index f323503..ff5d40f 100644
--- a/nodes_sampler.py
+++ b/nodes_sampler.py
@@ -1190,6 +1190,12 @@ class WanVideoSampler:
                 log.info(f" {k}: {v.shape if isinstance(v, torch.Tensor) else v}")
             one_to_all_data = one_to_all_embeds.copy()
             one_to_all_data = dict_to_device(one_to_all_data, device, dtype)
+            if one_to_all_embeds.get("pose_images") is not None:
+                pose_images = transformer.input_hint_block(one_to_all_data.pop("pose_images"))
+                if one_to_all_embeds.get("ref_latent_pos") is not None:
+                    pose_prefix_image = transformer.input_hint_block(one_to_all_data.pop("pose_prefix_image"))
+                    pose_images = torch.cat([pose_prefix_image, pose_images],dim=2)
+                one_to_all_data["controlnet_tokens"] = pose_images.flatten(2).transpose(1, 2)
 
         #region model pred
         def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
diff --git a/wanvideo/modules/model.py b/wanvideo/modules/model.py
index db2b73b..692400a 100644
--- a/wanvideo/modules/model.py
+++ b/wanvideo/modules/model.py
@@ -2308,18 +2308,11 @@ class WanModel(torch.nn.Module):
                 suffix_frames += 1
             onetoall_ref_block_samples, onetoall_freqs = self.refextractor(ref_cond_latent, timestep=t)
 
-            controlnet_video = one_to_all_input.get("pose_images")
-            controlnet_tokens = None
             # pose controlnet
-            if controlnet_video is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
+            controlnet_tokens = one_to_all_input.get("controlnet_tokens", None)
+            if controlnet_tokens is not None and one_to_all_input['controlnet_start_percent'] <= current_step_percentage <= one_to_all_input['controlnet_end_percent']:
                 onetoall_control_enabled = True
                 onetoall_control_strength = one_to_all_input.get("controlnet_strength", 1.0)
-                image_pose = one_to_all_input.get("pose_prefix_image")
-                controlnet_video = self.input_hint_block(controlnet_video)
-                controlnet_prefix = self.input_hint_block(image_pose)
-                if ref_cond_latent is not None:
-                    controlnet_video = torch.cat([controlnet_prefix, controlnet_video],dim=2)
-                controlnet_tokens = controlnet_video.flatten(2).transpose(1, 2)
 
             #uni3c controlnet
             if uni3c_data is not None:
@@ -3003,20 +2996,31 @@ class WanModel(torch.nn.Module):
             #controlnet
             if (controlnet is not None) and (b % controlnet["controlnet_stride"] == 0) and (b // controlnet["controlnet_stride"] < len(controlnet["controlnet_states"])):
                 x[:, :self.original_seq_len] += controlnet["controlnet_states"][b // controlnet["controlnet_stride"]].to(x) * controlnet["controlnet_weight"]
 
-            # One-to-All controlnet
+            # One-to-All-Animation controlnet
             if onetoall_control_enabled:
                 if prev_x is not None and (b - 1) < len(self.controlnet.blocks):
+                    tqdm.write(f"Applying One-to-All ControlNet at block {b}")
                     if b == 1:
                         ctrl_in = prev_x + controlnet_tokens
                     elif prev_control is not None:
                         ctrl_in = prev_control
+                    self.controlnet.blocks[b - 1].to(self.main_device)
                     control_out = self.controlnet.blocks[b - 1](ctrl_in, e0, seq_lens, freqs, split_rope=False)
+                    self.controlnet.blocks[b - 1].to(self.offload_device, non_blocking=self.use_non_blocking)
                     prev_control = control_out
                     control_out_proj = self.controlnet_zero[b - 1](control_out)
                     x = x + control_out_proj * onetoall_control_strength
 
-                prev_x = x
+                if b < len(self.controlnet.blocks): # Store prev_x only while controlnet is active
+                    prev_x = x
+                elif b == len(self.controlnet.blocks): # Controlnet done, free memory
+                    prev_x = None
+                    prev_control = None
+                    if controlnet_tokens is not None:
+                        del controlnet_tokens
+                        controlnet_tokens = None
+                    mm.soft_empty_cache()
 
         if lynx_ref_feature_extractor:
             return lynx_ref_buffer
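
Below is a minimal, hedged sketch (not part of the patch) of what the new sampler-side precompute produces: the pose frames go through the hint encoder once, an optional prefix frame is prepended along the frame axis, and the result is flattened into (batch, tokens, channels) ControlNet tokens. The shapes and the Conv3d stand-in for transformer.input_hint_block are illustrative assumptions, not the repo's actual values.

import torch
import torch.nn as nn

# Hypothetical stand-in for transformer.input_hint_block; all shapes are assumed.
hint_block = nn.Conv3d(3, 16, kernel_size=1)

pose_images = torch.randn(1, 3, 8, 32, 32)        # (batch, channels, frames, height, width)
pose_prefix_image = torch.randn(1, 3, 1, 32, 32)  # single prefix/reference pose frame

feats = hint_block(pose_images)                   # (1, 16, 8, 32, 32)
prefix_feats = hint_block(pose_prefix_image)      # (1, 16, 1, 32, 32)
feats = torch.cat([prefix_feats, feats], dim=2)   # prepend along the frame axis -> 9 frames

# Same reshape the patch uses: collapse (frames, height, width) into a token axis.
controlnet_tokens = feats.flatten(2).transpose(1, 2)  # (batch, 9*32*32, 16) = (batch, tokens, channels)
print(controlnet_tokens.shape)                        # torch.Size([1, 9216, 16])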