ComfyUI-WanVideoWrapper/nodes_model_loading.py
kijai 48fa904ad8 fp8 matmul for scaled models
Fp8 matmul (fp8_fast) doesn't seem feasible with unmerged LoRAs as you'd need to first upcast, then apply LoRA, then downcast back to fp8 and that is too slow. Direct adding in fp8 is also not possible since that's just not something fp8 dtypes support.
2025-08-09 10:17:11 +03:00
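As a rough illustration of the limitation described in the commit message (a minimal, hypothetical sketch with made-up shapes, not part of the file below): keeping a LoRA unmerged next to fp8-stored weights forces an upcast/apply/downcast round trip on every use, and fp8 tensors cannot be added to directly.

import torch  # illustrative sketch only

w_fp8 = torch.randn(128, 128).to(torch.float8_e4m3fn)   # hypothetical fp8-stored weight
lora_A = torch.randn(16, 128, dtype=torch.bfloat16)     # hypothetical LoRA down projection
lora_B = torch.randn(128, 16, dtype=torch.bfloat16)     # hypothetical LoRA up projection

# Unmerged path: upcast, add the LoRA delta, downcast back to fp8 -- too slow to repeat per call
w = w_fp8.to(torch.bfloat16) + lora_B @ lora_A
w_fp8 = w.to(torch.float8_e4m3fn)

# Direct addition in fp8 is not possible, the fp8 dtypes don't implement it:
# w_fp8 += (lora_B @ lora_A).to(torch.float8_e4m3fn)  # raises: add not implemented for fp8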


import torch
import os, gc, uuid
from .utils import log, apply_lora
import numpy as np
from tqdm import tqdm
from .wanvideo.modules.model import WanModel
from .wanvideo.modules.t5 import T5EncoderModel
from .wanvideo.modules.clip import CLIPModel
from accelerate import init_empty_weights
from .utils import set_module_tensor_to_device
from .fp8_optimization import convert_linear_with_lora_and_scale
import folder_paths
import comfy.model_management as mm
from comfy.utils import load_torch_file, ProgressBar
import comfy.model_base
from comfy.sd import load_lora_for_models
script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
try:
from server import PromptServer
except:
PromptServer = None
#from city96's gguf nodes
def update_folder_names_and_paths(key, targets=[]):
# check for existing key
base = folder_paths.folder_names_and_paths.get(key, ([], {}))
base = base[0] if isinstance(base[0], (list, set, tuple)) else []
# find base key & add w/ fallback, sanity check + warning
target = next((x for x in targets if x in folder_paths.folder_names_and_paths), targets[0])
orig, _ = folder_paths.folder_names_and_paths.get(target, ([], {}))
folder_paths.folder_names_and_paths[key] = (orig or base, {".gguf"})
if base and base != orig:
log.warning(f"Unknown file list already present on key {key}: {base}")
update_folder_names_and_paths("unet_gguf", ["diffusion_models", "unet"])
class WanVideoModel(comfy.model_base.BaseModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pipeline = {}
def __getitem__(self, k):
return self.pipeline[k]
def __setitem__(self, k, v):
self.pipeline[k] = v
try:
from comfy.latent_formats import Wan21, Wan22
latent_format = Wan21
except: #for backwards compatibility
log.warning("Wan21 latent format not found, update ComfyUI for better livepreview")
from comfy.latent_formats import HunyuanVideo
latent_format = HunyuanVideo
class WanVideoModelConfig:
def __init__(self, dtype, latent_format=latent_format):
self.unet_config = {}
self.unet_extra_config = {}
self.latent_format = latent_format
#self.latent_format.latent_channels = 16
self.manual_cast_dtype = dtype
self.sampling_settings = {"multiplier": 1.0}
self.memory_usage_factor = 2.0
self.unet_config["disable_unet_model_creation"] = True
def filter_state_dict_by_blocks(state_dict, blocks_mapping, layer_filter=[]):
filtered_dict = {}
if isinstance(layer_filter, str):
layer_filters = [layer_filter] if layer_filter else []
else:
# Filter out empty strings
layer_filters = [f for f in layer_filter if f] if layer_filter else []
#print("layer_filter: ", layer_filters)
for key in state_dict:
if not any(filter_str in key for filter_str in layer_filters):
if 'blocks.' in key:
block_pattern = key.split('diffusion_model.')[1].split('.', 2)[0:2]
block_key = f'{block_pattern[0]}.{block_pattern[1]}.'
if block_key in blocks_mapping:
filtered_dict[key] = state_dict[key]
else:
filtered_dict[key] = state_dict[key]
for key in filtered_dict:
print(key)
#from safetensors.torch import save_file
#save_file(filtered_dict, "filtered_state_dict_2.safetensors")
return filtered_dict
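# Normalize LoRA state dict keys from the various trainer conventions (lycoris/aitoolkit,
# Diffusers "transformer.", DiffSynth/UniAnimate "pipe.dit.", Fun-style "lora_unet__" and
# finetrainer "attn1/attn2" naming) to the "diffusion_model.*.lora_A/lora_B" layout used here.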
def standardize_lora_key_format(lora_sd):
new_sd = {}
for k, v in lora_sd.items():
# aitoolkit/lycoris format
if k.startswith("lycoris_blocks_"):
k = k.replace("lycoris_blocks_", "blocks.")
k = k.replace("_cross_attn_", ".cross_attn.")
k = k.replace("_self_attn_", ".self_attn.")
k = k.replace("_ffn_net_0_proj", ".ffn.0")
k = k.replace("_ffn_net_2", ".ffn.2")
k = k.replace("to_out_0", "o")
# Diffusers format
if k.startswith('transformer.'):
k = k.replace('transformer.', 'diffusion_model.')
if k.startswith('pipe.dit.'): #unianimate-dit/diffsynth
k = k.replace('pipe.dit.', 'diffusion_model.')
if k.startswith('blocks.'):
k = k.replace('blocks.', 'diffusion_model.blocks.')
k = k.replace('.default.', '.')
# Fun LoRA format
if k.startswith('lora_unet__'):
# Split into main path and weight type parts
parts = k.split('.')
main_part = parts[0] # e.g. lora_unet__blocks_0_cross_attn_k
weight_type = '.'.join(parts[1:]) if len(parts) > 1 else None # e.g. lora_down.weight
# Process the main part - convert from underscore to dot format
if 'blocks_' in main_part:
# Extract components
components = main_part[len('lora_unet__'):].split('_')
# Start with diffusion_model
new_key = "diffusion_model"
# Add blocks.N
if components[0] == 'blocks':
new_key += f".blocks.{components[1]}"
# Handle different module types
idx = 2
if idx < len(components):
if components[idx] == 'self' and idx+1 < len(components) and components[idx+1] == 'attn':
new_key += ".self_attn"
idx += 2
elif components[idx] == 'cross' and idx+1 < len(components) and components[idx+1] == 'attn':
new_key += ".cross_attn"
idx += 2
elif components[idx] == 'ffn':
new_key += ".ffn"
idx += 1
# Add the component (k, q, v, o) and handle img suffix
if idx < len(components):
component = components[idx]
idx += 1
# Check for img suffix
if idx < len(components) and components[idx] == 'img':
component += '_img'
idx += 1
new_key += f".{component}"
# Handle weight type - this is the critical fix
if weight_type:
if weight_type == 'alpha':
new_key += '.alpha'
elif weight_type == 'lora_down.weight' or weight_type == 'lora_down':
new_key += '.lora_A.weight'
elif weight_type == 'lora_up.weight' or weight_type == 'lora_up':
new_key += '.lora_B.weight'
else:
# Keep original weight type if not matching our patterns
new_key += f'.{weight_type}'
# Add .weight suffix if missing
if not new_key.endswith('.weight'):
new_key += '.weight'
k = new_key
else:
# For other lora_unet__ formats (head, embeddings, etc.)
new_key = main_part.replace('lora_unet__', 'diffusion_model.')
# Fix specific component naming patterns
new_key = new_key.replace('_self_attn', '.self_attn')
new_key = new_key.replace('_cross_attn', '.cross_attn')
new_key = new_key.replace('_ffn', '.ffn')
new_key = new_key.replace('blocks_', 'blocks.')
new_key = new_key.replace('head_head', 'head.head')
new_key = new_key.replace('img_emb', 'img_emb')
new_key = new_key.replace('text_embedding', 'text.embedding')
new_key = new_key.replace('time_embedding', 'time.embedding')
new_key = new_key.replace('time_projection', 'time.projection')
# Replace remaining underscores with dots, carefully
parts = new_key.split('.')
final_parts = []
for part in parts:
if part in ['img_emb', 'self_attn', 'cross_attn']:
final_parts.append(part) # Keep these intact
else:
final_parts.append(part.replace('_', '.'))
new_key = '.'.join(final_parts)
# Handle weight type
if weight_type:
if weight_type == 'alpha':
new_key += '.alpha'
elif weight_type == 'lora_down.weight' or weight_type == 'lora_down':
new_key += '.lora_A.weight'
elif weight_type == 'lora_up.weight' or weight_type == 'lora_up':
new_key += '.lora_B.weight'
else:
new_key += f'.{weight_type}'
if not new_key.endswith('.weight'):
new_key += '.weight'
k = new_key
# Handle special embedded components
special_components = {
'time.projection': 'time_projection',
'img.emb': 'img_emb',
'text.emb': 'text_emb',
'time.emb': 'time_emb',
}
for old, new in special_components.items():
if old in k:
k = k.replace(old, new)
# Fix diffusion.model -> diffusion_model
if k.startswith('diffusion.model.'):
k = k.replace('diffusion.model.', 'diffusion_model.')
# Finetrainer format
if '.attn1.' in k:
k = k.replace('.attn1.', '.cross_attn.')
k = k.replace('.to_k.', '.k.')
k = k.replace('.to_q.', '.q.')
k = k.replace('.to_v.', '.v.')
k = k.replace('.to_out.0.', '.o.')
elif '.attn2.' in k:
k = k.replace('.attn2.', '.cross_attn.')
k = k.replace('.to_k.', '.k.')
k = k.replace('.to_q.', '.q.')
k = k.replace('.to_v.', '.v.')
k = k.replace('.to_out.0.', '.o.')
if "img_attn.proj" in k:
k = k.replace("img_attn.proj", "img_attn_proj")
if "img_attn.qkv" in k:
k = k.replace("img_attn.qkv", "img_attn_qkv")
if "txt_attn.proj" in k:
k = k.replace("txt_attn.proj", "txt_attn_proj")
if "txt_attn.qkv" in k:
k = k.replace("txt_attn.qkv", "txt_attn_qkv")
new_sd[k] = v
return new_sd
class WanVideoBlockSwap:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1, "tooltip": "Number of transformer blocks to swap, the 14B model has 40, while the 1.3B model has 30 blocks"}),
"offload_img_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload img_emb to offload_device"}),
"offload_txt_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload time_emb to offload_device"}),
},
"optional": {
"use_non_blocking": ("BOOLEAN", {"default": False, "tooltip": "Use non-blocking memory transfer for offloading, reserves more RAM but is faster"}),
"vace_blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 15, "step": 1, "tooltip": "Number of VACE blocks to swap, the VACE model has 15 blocks"}),
},
}
RETURN_TYPES = ("BLOCKSWAPARGS",)
RETURN_NAMES = ("block_swap_args",)
FUNCTION = "setargs"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Settings for block swapping, reduces VRAM use by swapping blocks to CPU memory"
def setargs(self, **kwargs):
return (kwargs, )
class WanVideoVRAMManagement:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"offload_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Percentage of parameters to offload"}),
},
}
RETURN_TYPES = ("VRAM_MANAGEMENTARGS",)
RETURN_NAMES = ("vram_management_args",)
FUNCTION = "setargs"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"
def setargs(self, **kwargs):
return (kwargs, )
class WanVideoTorchCompileSettings:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
"compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only the transformer blocks, usually enough and can make compilation faster and less error prone"}),
},
"optional": {
"dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}),
},
}
RETURN_TYPES = ("WANCOMPILEARGS",)
RETURN_NAMES = ("torch_compile_args",)
FUNCTION = "set_args"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128):
compile_args = {
"backend": backend,
"fullgraph": fullgraph,
"mode": mode,
"dynamic": dynamic,
"dynamo_cache_size_limit": dynamo_cache_size_limit,
"dynamo_recompile_limit": dynamo_recompile_limit,
"compile_transformer_blocks_only": compile_transformer_blocks_only,
}
return (compile_args, )
class WanVideoLoraSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (folder_paths.get_filename_list("loras"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
"blocks":("SELECTEDBLOCKS", ),
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current one. No effect if merge_loras is False"}),
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("WANVIDLORA",)
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"
def getlorapath(self, lora, strength, unique_id, blocks={}, prev_lora=None, low_mem_load=False, merge_loras=True):
if not merge_loras:
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
loras_list = []
if not isinstance(strength, list):
strength = round(strength, 4)
if strength == 0.0:
if prev_lora is not None:
loras_list.extend(prev_lora)
return (loras_list,)
try:
lora_path = folder_paths.get_full_path("loras", lora)
except:
lora_path = lora
# Load metadata from the safetensors file
metadata = {}
try:
from safetensors.torch import safe_open
with safe_open(lora_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
except Exception as e:
print(f"Could not load metadata from {lora}: {e}")
if unique_id and PromptServer is not None:
try:
if metadata:
# Build table rows for metadata
metadata_rows = ""
for key, value in metadata.items():
# Format value - handle special cases
if isinstance(value, dict):
formatted_value = "<pre>" + "\n".join([f"{k}: {v}" for k, v in value.items()]) + "</pre>"
elif isinstance(value, (list, tuple)):
formatted_value = "<pre>" + "\n".join([str(item) for item in value]) + "</pre>"
else:
formatted_value = str(value)
metadata_rows += f"<tr><td><b>{key}</b></td><td>{formatted_value}</td></tr>"
PromptServer.instance.send_progress_text(
f"<details>"
f"<summary><b>Metadata</b></summary>"
f"<table border='0' cellpadding='3'>"
f"<tr><td colspan='2'><b>Metadata</b></td></tr>"
f"{metadata_rows}"
f"</table>"
f"</details>",
unique_id
)
except Exception as e:
print(f"Error displaying metadata: {e}")
pass
lora = {
"path": lora_path,
"strength": strength,
"name": lora.split(".")[0],
"blocks": blocks.get("selected_blocks", {}),
"layer_filter": blocks.get("layer_filter", ""),
"low_mem_load": low_mem_load,
"merge_loras": merge_loras,
}
if prev_lora is not None:
loras_list.extend(prev_lora)
loras_list.append(lora)
return (loras_list,)
class WanVideoLoraSelectMulti:
@classmethod
def INPUT_TYPES(s):
lora_files = folder_paths.get_filename_list("loras")
lora_files = ["none"] + lora_files # Add "none" as the first option
return {
"required": {
"lora_0": (lora_files, {"default": "none"}),
"strength_0": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"lora_1": (lora_files, {"default": "none"}),
"strength_1": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"lora_2": (lora_files, {"default": "none"}),
"strength_2": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"lora_3": (lora_files, {"default": "none"}),
"strength_3": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"lora_4": (lora_files, {"default": "none"}),
"strength_4": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
"blocks":("SELECTEDBLOCKS", ),
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading. No effect if merge_loras is False"}),
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
}
}
RETURN_TYPES = ("WANVIDLORA",)
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"
def getlorapath(self, lora_0, strength_0, lora_1, strength_1, lora_2, strength_2,
lora_3, strength_3, lora_4, strength_4, blocks={}, prev_lora=None,
low_mem_load=False, merge_loras=True):
if not merge_loras:
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
loras_list = list(prev_lora) if prev_lora else []
lora_inputs = [
(lora_0, strength_0),
(lora_1, strength_1),
(lora_2, strength_2),
(lora_3, strength_3),
(lora_4, strength_4)
]
for lora_name, strength in lora_inputs:
s = round(strength, 4) if not isinstance(strength, list) else strength
if not lora_name or lora_name == "none" or s == 0.0:
continue
loras_list.append({
"path": folder_paths.get_full_path("loras", lora_name),
"strength": s,
"name": lora_name.split(".")[0],
"blocks": blocks.get("selected_blocks", {}),
"layer_filter": blocks.get("layer_filter", ""),
"low_mem_load": low_mem_load,
"merge_loras": merge_loras,
})
return (loras_list,)
class WanVideoVACEModelSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vace_model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' VACE model to use when not using model that has it included"}),
},
}
RETURN_TYPES = ("VACEPATH",)
RETURN_NAMES = ("vace_model", )
FUNCTION = "getvacepath"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "VACE model to use when not using model that has it included, loaded from 'ComfyUI/models/diffusion_models'"
def getvacepath(self, vace_model):
vace_model = {
"path": folder_paths.get_full_path("diffusion_models", vace_model),
}
return (vace_model,)
class WanVideoLoraBlockEdit:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
arg_dict = {}
argument = ("BOOLEAN", {"default": True})
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
return {"required": arg_dict, "optional": {"layer_filter": ("STRING", {"default": "", "multiline": True})}}
RETURN_TYPES = ("SELECTEDBLOCKS", )
RETURN_NAMES = ("blocks", )
OUTPUT_TOOLTIPS = ("The modified lora model",)
FUNCTION = "select"
CATEGORY = "WanVideoWrapper"
def select(self, layer_filter="", **kwargs):
selected_blocks = {k: v for k, v in kwargs.items() if v is True and isinstance(v, bool)}
print("Selected blocks LoRA: ", selected_blocks)
selected = {
"selected_blocks": selected_blocks,
"layer_filter": [x.strip() for x in layer_filter.split(",")]
}
return (selected,)
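# Build a mapping from common LoRA key spellings (lora_unet_*, lycoris_*, diffusers
# processor keys and plain module paths) to this model's state dict keys, used when
# applying LoRAs on the fly instead of merging them.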
def model_lora_keys_unet(model, key_map={}):
sd = model.state_dict()
sdk = sd.keys()
for k in sdk:
k = k.replace("_orig_mod.", "")
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
return key_map
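# Variant of ComfyUI's ModelPatcher.add_patches that also matches keys where
# torch.compile has inserted an "_orig_mod" segment after the block index, so patches
# still resolve against a compiled transformer.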
def add_patches(patcher, patches, strength_patch=1.0, strength_model=1.0):
with patcher.use_ejected():
p = set()
model_sd = patcher.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
# Check for key, or key with '._orig_mod' inserted after block number, in model_sd
key_in_sd = key in model_sd
key_orig_mod = None
if not key_in_sd:
# Try to insert '._orig_mod' after the block number if pattern matches
parts = key.split('.')
# Look for 'blocks', block number, then insert
try:
idx = parts.index('blocks')
if idx + 1 < len(parts):
# Only if the next part is a number
if parts[idx+1].isdigit():
new_parts = parts[:idx+2] + ['_orig_mod'] + parts[idx+2:]
key_orig_mod = '.'.join(new_parts)
except ValueError:
pass
key_orig_mod_in_sd = key_orig_mod is not None and key_orig_mod in model_sd
if key_in_sd or key_orig_mod_in_sd:
actual_key = key if key_in_sd else key_orig_mod
p.add(k)
current_patches = patcher.patches.get(actual_key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
patcher.patches[actual_key] = current_patches
patcher.patches_uuid = uuid.uuid4()
return list(p)
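# Like comfy.sd.load_lora_for_models, but using the key map and add_patches variant
# above; LoRA keys that cannot be matched to the model are logged as warnings.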
def load_lora_for_models_mod(model, lora, strength_model):
key_map = {}
if model is not None:
key_map = model_lora_keys_unet(model.model, key_map)
loaded = comfy.lora.load_lora(lora, key_map)
new_modelpatcher = model.clone()
k = add_patches(new_modelpatcher, loaded, strength_model)
k = set(k)
for x in loaded:
if (x not in k):
log.warning("NOT LOADED {}".format(x))
return (new_modelpatcher)
class WanVideoSetLoRAs:
@classmethod
def INPUT_TYPES(s):
return {
"required":
{
"model": ("WANVIDEOMODEL", ),
},
"optional": {
"lora": ("WANVIDLORA", ),
}
}
RETURN_TYPES = ("WANVIDEOMODEL",)
RETURN_NAMES = ("model", )
FUNCTION = "setlora"
CATEGORY = "WanVideoWrapper"
EXPERIMENTAL = True
DESCRIPTION = "Sets the LoRA weights to be used directly in linear layers of the model, this does NOT merge LoRAs"
def setlora(self, model, lora=None):
if lora is None:
return (model,)
patcher = model.clone()
merge_loras = False
for l in lora:
merge_loras = l.get("merge_loras", True)
if merge_loras is True:
raise ValueError("Set LoRA node does not use low_mem_load and can't merge LoRAs, disable 'merge_loras' in the LoRA select node.")
patcher.model_options['transformer_options']["lora_scheduling_enabled"] = False
for l in lora:
log.info(f"Loading LoRA: {l['name']} with strength: {l['strength']}")
lora_path = l["path"]
lora_strength = l["strength"]
if isinstance(lora_strength, list):
if merge_loras:
raise ValueError("LoRA strength should be a single value when merge_loras=True")
patcher.model_options['transformer_options']["lora_scheduling_enabled"] = True
if lora_strength == 0:
log.warning(f"LoRA {lora_path} has strength 0, skipping...")
continue
lora_sd = load_torch_file(lora_path, safe_load=True)
if "dwpose_embedding.0.weight" in lora_sd: #unianimate
raise NotImplementedError("Unianimate LoRA patching is not implemented in this node.")
lora_sd = standardize_lora_key_format(lora_sd)
if l["blocks"]:
lora_sd = filter_state_dict_by_blocks(lora_sd, l["blocks"], l.get("layer_filter", []))
if "diffusion_model.patch_embedding.lora_A.weight" in lora_sd:
raise NotImplementedError("Control LoRA patching is not implemented in this node.")
patcher = load_lora_for_models_mod(patcher, lora_sd, lora_strength)
del lora_sd
if 'transformer_options' not in patcher.model_options:
patcher.model_options['transformer_options'] = {}
patcher.model_options['transformer_options']["patch_linear"] = True
return (patcher,)
#region Model loading
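# Main transformer loader: reads the checkpoint (safetensors or GGUF), infers the
# architecture from tensor shapes/keys, builds WanModel on the meta device, applies
# optional patches (VACE, ReCamMaster, FantasyTalking, MultiTalk), then assigns the
# weights with the requested precision/quantization and handles LoRA merging.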
class WanVideoModelLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
"base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
"quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "optional quantization method"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
},
"optional": {
"attention_mode": ([
"sdpa",
"flash_attn_2",
"flash_attn_3",
"sageattn",
"sageattn_3",
"flex_attention",
"radial_sage_attention",
], {"default": "sdpa"}),
"compile_args": ("WANCOMPILEARGS", ),
"block_swap_args": ("BLOCKSWAPARGS", ),
"lora": ("WANVIDLORA", {"default": None}),
"vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}),
"vace_model": ("VACEPATH", {"default": None, "tooltip": "VACE model to use when not using model that has it included"}),
"fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}),
"multitalk_model": ("MULTITALKMODEL", {"default": None, "tooltip": "Multitalk model"}),
}
}
RETURN_TYPES = ("WANVIDEOMODEL",)
RETURN_NAMES = ("model", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
def loadmodel(self, model, base_precision, load_device, quantization,
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, vram_management_args=None, vace_model=None, fantasytalking_model=None, multitalk_model=None):
assert not (vram_management_args is not None and block_swap_args is not None), "Can't use both block_swap_args and vram_management_args at the same time"
lora_low_mem_load = merge_loras = False
if lora is not None:
for l in lora:
lora_low_mem_load = l.get("low_mem_load", False)
merge_loras = l.get("merge_loras", True)
transformer = None
mm.unload_all_models()
mm.cleanup_models()
mm.soft_empty_cache()
manual_offloading = True
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
gguf = False
if model.endswith(".gguf"):
if quantization != "disabled":
raise ValueError("Quantization should be disabled when loading GGUF models.")
quantization = "gguf"
gguf = True
if merge_loras is True:
raise ValueError("GGUF models do not support LoRA merging, please disable merge_loras in the LoRA select node.")
manual_offloading = True
transformer_load_device = device if load_device == "main_device" else offload_device
base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]
if base_precision == "fp16_fast":
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
torch.backends.cuda.matmul.allow_fp16_accumulation = True
else:
raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum currently")
else:
try:
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
torch.backends.cuda.matmul.allow_fp16_accumulation = False
except:
pass
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
if not gguf:
sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
else:
from diffusers.models.model_loading_utils import load_gguf_checkpoint
sd = load_gguf_checkpoint(model_path)
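# If quantization is left at "disabled", detect fp8 (and scaled fp8) checkpoints from
# the stored tensor dtypes so the matching quantization path is used automatically.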
if quantization == "disabled":
for k, v in sd.items():
if isinstance(v, torch.Tensor):
if v.dtype == torch.float8_e4m3fn:
quantization = "fp8_e4m3fn"
if "scaled_fp8" in sd:
quantization = "fp8_e4m3fn_scaled"
break
elif v.dtype == torch.float8_e5m2:
quantization = "fp8_e5m2"
if "scaled_fp8" in sd:
quantization = "fp8_e5m2_scaled"
break
if "scaled_fp8" in sd and "scaled" not in quantization:
raise ValueError("The model is a scaled fp8 model, please set quantization to '_scaled'")
if "vace_blocks.0.after_proj.weight" in sd and not "patch_embedding.weight" in sd:
raise ValueError("You are attempting to load a VACE module as a WanVideo model, instead you should use the vace_model input and matching T2V base model")
if vace_model is not None:
if gguf:
if not vace_model["path"].endswith(".gguf"):
raise ValueError("With GGUF main model the VACE module must also be a GGUF quantized, if the main model already has VACE included, you can disconnect the VACE module loader")
vace_sd = load_gguf_checkpoint(model_path)
else:
vace_sd = load_torch_file(vace_model["path"], device=transformer_load_device, safe_load=True)
sd.update(vace_sd)
first_key = next(iter(sd))
if first_key.startswith("model.diffusion_model."):
new_sd = {}
for key, value in sd.items():
new_key = key.replace("model.diffusion_model.", "", 1)
new_sd[new_key] = value
sd = new_sd
elif first_key.startswith("model."):
new_sd = {}
for key, value in sd.items():
new_key = key.replace("model.", "", 1)
new_sd[new_key] = value
sd = new_sd
if not "patch_embedding.weight" in sd:
raise ValueError("Invalid WanVideo model selected")
dim = sd["patch_embedding.weight"].shape[0]
in_features = sd["blocks.0.self_attn.k.weight"].shape[1]
out_features = sd["blocks.0.self_attn.k.weight"].shape[0]
in_channels = sd["patch_embedding.weight"].shape[1]
log.info(f"Detected model in_channels: {in_channels}")
ffn_dim = sd["blocks.0.ffn.0.bias"].shape[0]
ffn2_dim = sd["blocks.0.ffn.2.weight"].shape[1]
model_type = "t2v"
if not "text_embedding.0.weight" in sd:
model_type = "no_cross_attn" #minimaxremover
elif "model_type.Wan2_1-FLF2V-14B-720P" in sd or "img_emb.emb_pos" in sd or "flf2v" in model.lower():
model_type = "fl2v"
elif in_channels in [36, 48]:
if "blocks.0.cross_attn.k_img.weight" not in sd:
model_type = "t2v"
else:
model_type = "i2v"
elif in_channels == 16:
model_type = "t2v"
elif "control_adapter.conv.weight" in sd:
model_type = "t2v"
out_dim = 16
if dim == 5120: #14B
num_heads = 40
num_layers = 40
elif dim == 3072: #5B
num_heads = 24
num_layers = 30
out_dim = 48
model_type = "t2v" #5B no img crossattn
else: #1.3B
num_heads = 12
num_layers = 30
vace_layers, vace_in_dim = None, None
if "vace_blocks.0.after_proj.weight" in sd:
if in_channels != 16:
raise ValueError("VACE only works properly with T2V models.")
model_type = "t2v"
if dim == 5120:
vace_layers = [0, 5, 10, 15, 20, 25, 30, 35]
else:
vace_layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
vace_in_dim = 96
log.info(f"Model type: {model_type}, num_heads: {num_heads}, num_layers: {num_layers}")
teacache_coefficients_map = {
"1_3B": {
"e": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01],
"e0": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
},
"14B": {
"e": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404],
"e0": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
},
"i2v_480": {
"e": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01],
"e0": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
},
"i2v_720":{
"e": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
"e0": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
},
}
magcache_ratios_map = {
"1_3B": np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]),
"14B": np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]),
"i2v_480": np.array([1.0]*2+[0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]),
"i2v_720": np.array([1.0]*2+[0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]),
}
model_variant = "14B" #default to this
if model_type == "i2v" or model_type == "fl2v":
if "480" in model or "fun" in model.lower() or "a2" in model.lower() or "540" in model: #just a guess for the Fun model for now...
model_variant = "i2v_480"
elif "720" in model:
model_variant = "i2v_720"
elif model_type == "t2v":
model_variant = "14B"
if dim == 1536:
model_variant = "1_3B"
if dim == 3072:
log.info(f"5B model detected, no Teacache or MagCache coefficients available, consider using EasyCache for this model")
log.info(f"Model variant detected: {model_variant}")
TRANSFORMER_CONFIG= {
"dim": dim,
"in_features": in_features,
"out_features": out_features,
"ffn_dim": ffn_dim,
"ffn2_dim": ffn2_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": out_dim,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
"attention_mode": attention_mode,
"rope_func": "comfy",
"main_device": device,
"offload_device": offload_device,
"teacache_coefficients": teacache_coefficients_map[model_variant],
"magcache_ratios": magcache_ratios_map[model_variant],
"vace_layers": vace_layers,
"vace_in_dim": vace_in_dim,
"inject_sample_info": True if "fps_embedding.weight" in sd else False,
"add_ref_conv": True if "ref_conv.weight" in sd else False,
"in_dim_ref_conv": sd["ref_conv.weight"].shape[1] if "ref_conv.weight" in sd else None,
"add_control_adapter": True if "control_adapter.conv.weight" in sd else False,
}
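# Instantiate the transformer on the meta device; actual weights are assigned
# tensor-by-tensor further down (accelerate path, GGUF path or LoRA merge path).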
with init_empty_weights():
transformer = WanModel(**TRANSFORMER_CONFIG)
transformer.eval()
#ReCamMaster
if "blocks.0.cam_encoder.weight" in sd:
log.info("ReCamMaster model detected, patching model...")
import torch.nn as nn
for block in transformer.blocks:
block.cam_encoder = nn.Linear(12, dim)
block.projector = nn.Linear(dim, dim)
block.cam_encoder.weight.data.zero_()
block.cam_encoder.bias.data.zero_()
block.projector.weight = nn.Parameter(torch.eye(dim))
block.projector.bias = nn.Parameter(torch.zeros(dim))
# FantasyTalking https://github.com/Fantasy-AMAP
if fantasytalking_model is not None:
log.info("FantasyTalking model detected, patching model...")
context_dim = fantasytalking_model["sd"]["proj_model.proj.weight"].shape[0]
import torch.nn as nn
for block in transformer.blocks:
block.cross_attn.k_proj = nn.Linear(context_dim, dim, bias=False)
block.cross_attn.v_proj = nn.Linear(context_dim, dim, bias=False)
sd.update(fantasytalking_model["sd"])
if multitalk_model is not None:
# init audio module
from .multitalk.multitalk import SingleStreamMultiAttention
from .wanvideo.modules.model import WanRMSNorm, WanLayerNorm
import torch.nn as nn  # needed for the nn.Identity() fallback below
norm_input_visual = True #dunno what this is
for block in transformer.blocks:
block.audio_cross_attn = SingleStreamMultiAttention(
dim=dim,
encoder_hidden_states_dim=768,
num_heads=num_heads,
qk_norm=False,
qkv_bias=True,
eps=transformer.eps,
norm_layer=WanRMSNorm,
class_range=24,
class_interval=4,
attention_mode=attention_mode,
)
block.norm_x = WanLayerNorm(dim, transformer.eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
log.info("MultiTalk model detected, patching model...")
sd.update(multitalk_model["sd"])
# Additional cond latents
if "add_conv_in.weight" in sd:
def zero_module(module):
for p in module.parameters():
torch.nn.init.zeros_(p)
return module
inner_dim = sd["add_conv_in.weight"].shape[0]
add_cond_in_dim = sd["add_conv_in.weight"].shape[1]
attn_cond_in_dim = sd["attn_conv_in.weight"].shape[1]
transformer.add_conv_in = torch.nn.Conv3d(add_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size)
transformer.add_proj = zero_module(torch.nn.Linear(inner_dim, inner_dim))
transformer.attn_conv_in = torch.nn.Conv3d(attn_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size)
latent_format=Wan22 if dim == 3072 else Wan21
comfy_model = WanVideoModel(
WanVideoModelConfig(base_dtype, latent_format=latent_format),
model_type=comfy.model_base.ModelType.FLOW,
device=device,
)
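# Standard (non-GGUF) weight assignment: linear weights are cast to the fp8 dtype when
# quantization is enabled, while keys listed in params_to_keep (norms, biases,
# embeddings, etc.) stay in base precision and patch_embedding is kept in fp32.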
if not gguf:
scale_weights = {}
if "fp8" in quantization:
for k, v in sd.items():
if k.endswith(".scale_weight"):
scale_weights[k] = v
if "fp8_e4m3fn" in quantization:
dtype = torch.float8_e4m3fn
elif "fp8_e5m2" in quantization:
dtype = torch.float8_e5m2
else:
dtype = base_dtype
params_to_keep = {"norm", "bias", "time_in", "patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "add"}
if not lora_low_mem_load:
log.info("Using accelerate to load and assign model weights to device...")
param_count = sum(1 for _ in transformer.named_parameters())
pbar = ProgressBar(param_count)
cnt = 0
for name, param in tqdm(transformer.named_parameters(),
desc=f"Loading transformer parameters to {transformer_load_device}",
total=param_count,
leave=True):
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
dtype_to_use = dtype if sd[name].dtype == dtype else dtype_to_use
if "modulation" in name or "norm" in name or "bias" in name:
dtype_to_use = base_dtype
if "patch_embedding" in name:
dtype_to_use = torch.float32
set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
cnt += 1
if cnt % 100 == 0:
pbar.update(100)
#for name, param in transformer.named_parameters():
# print(name, param.dtype, param.device, param.shape)
pbar.update_absolute(param_count)
comfy_model.diffusion_model = transformer
comfy_model.load_device = transformer_load_device
patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device)
patcher.model.is_patched = False
control_lora = False
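# LoRA handling: register the LoRA patches on the patcher, expand patch_embedding for
# Control-LoRA checkpoints (extra input channels) and patch the transformer directly
# for UniAnimate LoRAs.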
if lora is not None:
for l in lora:
log.info(f"Loading LoRA: {l['name']} with strength: {l['strength']}")
lora_path = l["path"]
lora_strength = l["strength"]
if isinstance(lora_strength, list):
if merge_loras:
raise ValueError("LoRA strength should be a single value when merge_loras=True")
transformer.lora_scheduling_enabled = True
if lora_strength == 0:
log.warning(f"LoRA {lora_path} has strength 0, skipping...")
continue
lora_sd = load_torch_file(lora_path, safe_load=True)
if "dwpose_embedding.0.weight" in lora_sd: #unianimate
from .unianimate.nodes import update_transformer
log.info("Unianimate LoRA detected, patching model...")
transformer = update_transformer(transformer, lora_sd)
lora_sd = standardize_lora_key_format(lora_sd)
if l["blocks"]:
lora_sd = filter_state_dict_by_blocks(lora_sd, l["blocks"], l.get("layer_filter", []))
#spacepxl's control LoRA patch
# for key in lora_sd.keys():
# print(key)
if "diffusion_model.patch_embedding.lora_A.weight" in lora_sd:
log.info("Control-LoRA detected, patching model...")
if not merge_loras:
log.warning("Control-LoRA patching is only supported with merge_loras=True, setting it to True")
merge_loras = True
control_lora = True
in_cls = transformer.patch_embedding.__class__ # nn.Conv3d
old_in_dim = transformer.in_dim # 16
new_in_dim = lora_sd["diffusion_model.patch_embedding.lora_A.weight"].shape[1]
assert new_in_dim == 32
new_in = in_cls(
new_in_dim,
transformer.patch_embedding.out_channels,
transformer.patch_embedding.kernel_size,
transformer.patch_embedding.stride,
transformer.patch_embedding.padding,
).to(device=device, dtype=torch.float32)
new_in.weight.zero_()
new_in.bias.zero_()
new_in.weight[:, :old_in_dim].copy_(transformer.patch_embedding.weight)
new_in.bias.copy_(transformer.patch_embedding.bias)
transformer.patch_embedding = new_in
transformer.expanded_patch_embedding = new_in
patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0)
del lora_sd
if not gguf and merge_loras:
log.info("Patching LoRA to the model...")
patcher = apply_lora(
patcher, device, transformer_load_device,
params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd,
low_mem_load=lora_low_mem_load, control_lora=control_lora, scale_weights=scale_weights)
scale_weights.clear()
patcher.patches.clear()
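# GGUF path: swap nn.Linear layers for GGUF-aware replacements that dequantize on the
# fly, keep the quantized tensors as uint8 and load everything else in base precision.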
if gguf:
#from diffusers.quantizers.gguf.utils import _replace_with_gguf_linear, GGUFParameter
from .gguf.gguf import _replace_with_gguf_linear, GGUFParameter
log.info("Using GGUF to load and assign model weights to device...")
param_count = sum(1 for _ in transformer.named_parameters())
out_features = sd["blocks.0.self_attn.k.weight"].shape[1]
patcher.model.diffusion_model = _replace_with_gguf_linear(patcher.model.diffusion_model, base_dtype, sd, patches=patcher.patches)
pbar = ProgressBar(param_count)
cnt = 0
for name, param in tqdm(patcher.model.diffusion_model.named_parameters(),
desc=f"Loading transformer parameters to {transformer_load_device}",
total=param_count,
leave=True):
#print(name, param.dtype, param.device, param.shape)
if isinstance(param, GGUFParameter):
dtype_to_use = torch.uint8
elif "patch_embedding" in name:
dtype_to_use = torch.float32
else:
dtype_to_use = base_dtype
set_module_tensor_to_device(patcher.model.diffusion_model, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
cnt += 1
if cnt % 100 == 0:
pbar.update(100)
#for name, param in transformer.named_parameters():
# print(name, param.dtype, param.device, param.shape)
#patcher.load(device, full_load=True)
pbar.update_absolute(param_count)
patcher.model.is_patched = True
patch_linear = "scaled" in quantization or not merge_loras
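# fp8 matmul (the *_fast modes) uses the fp8 weights directly in the matmul, which is
# why the check below rejects unmerged LoRAs: the weights would have to be upcast,
# have the LoRA applied and be downcast back to fp8 on every use (too slow), and fp8
# dtypes do not support adding the LoRA delta in place (see the commit message above).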
if "fast" in quantization:
if lora is not None and not merge_loras:
raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs")
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights)
patch_linear = False
del sd
if multitalk_model is not None:
transformer.audio_proj = multitalk_model["proj_model"]
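# Optional DiffSynth-style VRAM management: wrap Linear/Conv3d/norm modules so that a
# chosen percentage of parameters stays offloaded and is cast/moved for computation as needed.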
if vram_management_args is not None:
if gguf:
raise ValueError("GGUF models don't support vram management")
from .diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from .wanvideo.modules.model import WanLayerNorm, WanRMSNorm
total_params_in_model = sum(p.numel() for p in patcher.model.diffusion_model.parameters())
log.info(f"Total number of parameters in the loaded model: {total_params_in_model}")
offload_percent = vram_management_args["offload_percent"]
offload_params = int(total_params_in_model * offload_percent)
params_to_keep = total_params_in_model - offload_params
log.info(f"Selected params to offload: {offload_params}")
enable_vram_management(
patcher.model.diffusion_model,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
WanLayerNorm: AutoWrappedModule,
WanRMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device=offload_device,
onload_dtype=dtype,
onload_device=device,
computation_dtype=base_dtype,
computation_device=device,
),
max_num_param=params_to_keep,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device=offload_device,
onload_dtype=dtype,
onload_device=offload_device,
computation_dtype=base_dtype,
computation_device=device,
),
compile_args = compile_args,
)
if load_device == "offload_device" and patcher.model.diffusion_model.device != offload_device:
log.info(f"Moving diffusion model from {patcher.model.diffusion_model.device} to {offload_device}")
patcher.model.diffusion_model.to(offload_device)
gc.collect()
mm.soft_empty_cache()
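# Stash the loader settings on the model wrapper; other WanVideoWrapper nodes read
# these (quantization, block swap, compile args, etc.) at sampling time.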
patcher.model["dtype"] = base_dtype
patcher.model["base_path"] = model_path
patcher.model["model_name"] = model
patcher.model["manual_offloading"] = manual_offloading
patcher.model["quantization"] = quantization
patcher.model["auto_cpu_offload"] = True if vram_management_args is not None else False
patcher.model["control_lora"] = control_lora
patcher.model["compile_args"] = compile_args
patcher.model["gguf"] = gguf
patcher.model["fp8_matmul"] = "fast" in quantization
patcher.model["scale_weights"] = scale_weights
if 'transformer_options' not in patcher.model_options:
patcher.model_options['transformer_options'] = {}
patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args
patcher.model_options["transformer_options"]["patch_linear"] = patch_linear
patcher.model_options["transformer_options"]["merge_loras"] = merge_loras
for model in mm.current_loaded_models:
if model._model() == patcher:
mm.current_loaded_models.remove(model)
return (patcher,)
# class WanVideoSaveModel:
# @classmethod
# def INPUT_TYPES(s):
# return {
# "required": {
# "model": ("WANVIDEOMODEL", {"tooltip": "WANVideo model to save"}),
# "output_path": ("STRING", {"default": "", "multiline": False, "tooltip": "Path to save the model"}),
# },
# }
# RETURN_TYPES = ()
# FUNCTION = "savemodel"
# CATEGORY = "WanVideoWrapper"
# DESCRIPTION = "Saves the model including merged LoRAs and quantization to diffusion_models/WanVideoWrapperSavedModels"
# OUTPUT_NODE = True
# def savemodel(self, model, output_path):
# from safetensors.torch import save_file
# model_sd = model.model.diffusion_model.state_dict()
# for k in model_sd.keys():
# print("key:", k, "shape:", model_sd[k].shape, "dtype:", model_sd[k].dtype, "device:", model_sd[k].device)
# model_sd
# model_name = os.path.basename(model.model["model_name"])
# if not output_path:
# output_path = os.path.join(folder_paths.models_dir, "diffusion_models", "WanVideoWrapperSavedModels", "saved_" + model_name)
# else:
# output_path = os.path.join(output_path, model_name)
# log.info(f"Saving model to {output_path}")
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
# save_file(model_sd, output_path)
# return ()
#region load VAE
class WanVideoVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"}
),
"compile_args": ("WANCOMPILEARGS", ),
}
}
RETURN_TYPES = ("WANVAE",)
RETURN_NAMES = ("vae", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae'"
def loadmodel(self, model_name, precision, compile_args=None):
from .wanvideo.wan_video_vae import WanVideoVAE, WanVideoVAE38
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
#with open(os.path.join(script_directory, 'configs', 'hy_vae_config.json')) as f:
# vae_config = json.load(f)
model_path = folder_paths.get_full_path("vae", model_name)
vae_sd = load_torch_file(model_path, safe_load=True)
has_model_prefix = any(k.startswith("model.") for k in vae_sd.keys())
if not has_model_prefix:
vae_sd = {f"model.{k}": v for k, v in vae_sd.items()}
if vae_sd["model.conv2.weight"].shape[0] == 16:
vae = WanVideoVAE(dtype=dtype)
elif vae_sd["model.conv2.weight"].shape[0] == 48:
vae = WanVideoVAE38(dtype=dtype)
vae.load_state_dict(vae_sd)
vae.eval()
vae.to(device = offload_device, dtype = dtype)
if compile_args is not None:
vae.model.decoder = torch.compile(vae.model.decoder, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
return (vae,)
class WanVideoTinyVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}),
"parallel": ("BOOLEAN", {"default": False, "tooltip": "uses more memory but is faster"}),
}
}
RETURN_TYPES = ("WANVAE",)
RETURN_NAMES = ("vae", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae_approx'"
def loadmodel(self, model_name, precision, parallel=False):
from .taehv import TAEHV
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
model_path = folder_paths.get_full_path("vae_approx", model_name)
vae_sd = load_torch_file(model_path, safe_load=True)
vae = TAEHV(vae_sd, parallel=parallel)
vae.to(device = offload_device, dtype = dtype)
return (vae,)
class LoadWanVideoT5TextEncoder:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}),
"precision": (["fp32", "bf16"],
{"default": "bf16"}
),
},
"optional": {
"load_device": (["main_device", "offload_device"], {"default": "offload_device"}),
"quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
}
}
RETURN_TYPES = ("WANTEXTENCODER",)
RETURN_NAMES = ("wan_t5_model", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Loads Wan text_encoder model from 'ComfyUI/models/LLM'"
def loadmodel(self, model_name, precision, load_device="offload_device", quantization="disabled"):
text_encoder_load_device = device if load_device == "main_device" else offload_device
tokenizer_path = os.path.join(script_directory, "configs", "T5_tokenizer")
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
model_path = folder_paths.get_full_path("text_encoders", model_name)
sd = load_torch_file(model_path, safe_load=True)
if quantization == "disabled":
for k, v in sd.items():
if isinstance(v, torch.Tensor):
if v.dtype == torch.float8_e4m3fn:
quantization = "fp8_e4m3fn"
break
if "token_embedding.weight" not in sd and "shared.weight" not in sd:
raise ValueError("Invalid T5 text encoder model, this node expects the 'umt5-xxl' model")
if "scaled_fp8" in sd:
raise ValueError("Invalid T5 text encoder model, fp8 scaled is not supported by this node")
# Convert state dict keys from T5 format to the expected format
if "shared.weight" in sd:
log.info("Converting T5 text encoder model to the expected format...")
converted_sd = {}
for key, value in sd.items():
# Handle encoder block patterns
if key.startswith('encoder.block.'):
parts = key.split('.')
block_num = parts[2]
# Self-attention components
if 'layer.0.SelfAttention' in key:
if key.endswith('.k.weight'):
new_key = f"blocks.{block_num}.attn.k.weight"
elif key.endswith('.o.weight'):
new_key = f"blocks.{block_num}.attn.o.weight"
elif key.endswith('.q.weight'):
new_key = f"blocks.{block_num}.attn.q.weight"
elif key.endswith('.v.weight'):
new_key = f"blocks.{block_num}.attn.v.weight"
elif 'relative_attention_bias' in key:
new_key = f"blocks.{block_num}.pos_embedding.embedding.weight"
else:
new_key = key
# Layer norms
elif 'layer.0.layer_norm' in key:
new_key = f"blocks.{block_num}.norm1.weight"
elif 'layer.1.layer_norm' in key:
new_key = f"blocks.{block_num}.norm2.weight"
# Feed-forward components
elif 'layer.1.DenseReluDense' in key:
if 'wi_0' in key:
new_key = f"blocks.{block_num}.ffn.gate.0.weight"
elif 'wi_1' in key:
new_key = f"blocks.{block_num}.ffn.fc1.weight"
elif 'wo' in key:
new_key = f"blocks.{block_num}.ffn.fc2.weight"
else:
new_key = key
else:
new_key = key
elif key == "shared.weight":
new_key = "token_embedding.weight"
elif key == "encoder.final_layer_norm.weight":
new_key = "norm.weight"
else:
new_key = key
converted_sd[new_key] = value
sd = converted_sd
T5_text_encoder = T5EncoderModel(
text_len=512,
dtype=dtype,
device=text_encoder_load_device,
state_dict=sd,
tokenizer_path=tokenizer_path,
quantization=quantization
)
text_encoder = {
"model": T5_text_encoder,
"dtype": dtype,
"name": model_name,
}
return (text_encoder,)
class LoadWanVideoClipTextEncoder:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("clip_vision") + folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/clip_vision'"}),
"precision": (["fp16", "fp32", "bf16"],
{"default": "fp16"}
),
},
"optional": {
"load_device": (["main_device", "offload_device"], {"default": "offload_device"}),
}
}
RETURN_TYPES = ("CLIP_VISION",)
RETURN_NAMES = ("wan_clip_vision", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Loads Wan clip_vision model from 'ComfyUI/models/clip_vision'"
def loadmodel(self, model_name, precision, load_device="offload_device"):
text_encoder_load_device = device if load_device == "main_device" else offload_device
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
model_path = folder_paths.get_full_path("clip_vision", model_name)
# We also support legacy setups where the model is in the text_encoders folder
if model_path is None:
model_path = folder_paths.get_full_path("text_encoders", model_name)
sd = load_torch_file(model_path, safe_load=True)
if "log_scale" not in sd:
raise ValueError("Invalid CLIP model, this node expectes the 'open-clip-xlm-roberta-large-vit-huge-14' model")
clip_model = CLIPModel(dtype=dtype, device=device, state_dict=sd)
clip_model.model.to(text_encoder_load_device)
del sd
return (clip_model,)
NODE_CLASS_MAPPINGS = {
"WanVideoModelLoader": WanVideoModelLoader,
"WanVideoVAELoader": WanVideoVAELoader,
"WanVideoLoraSelect": WanVideoLoraSelect,
"WanVideoSetLoRAs": WanVideoSetLoRAs,
"WanVideoLoraBlockEdit": WanVideoLoraBlockEdit,
"WanVideoTinyVAELoader": WanVideoTinyVAELoader,
"WanVideoVACEModelSelect": WanVideoVACEModelSelect,
"WanVideoLoraSelectMulti": WanVideoLoraSelectMulti,
"WanVideoBlockSwap": WanVideoBlockSwap,
"WanVideoVRAMManagement": WanVideoVRAMManagement,
"WanVideoTorchCompileSettings": WanVideoTorchCompileSettings,
"LoadWanVideoT5TextEncoder": LoadWanVideoT5TextEncoder,
"LoadWanVideoClipTextEncoder": LoadWanVideoClipTextEncoder,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoModelLoader": "WanVideo Model Loader",
"WanVideoVAELoader": "WanVideo VAE Loader",
"WanVideoLoraSelect": "WanVideo Lora Select",
"WanVideoSetLoRAs": "WanVideo Set LoRAs",
"WanVideoLoraBlockEdit": "WanVideo Lora Block Edit",
"WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader",
"WanVideoVACEModelSelect": "WanVideo VACE Module Select",
"WanVideoLoraSelectMulti": "WanVideo Lora Select Multi",
"WanVideoBlockSwap": "WanVideo Block Swap",
"WanVideoVRAMManagement": "WanVideo VRAM Management",
"WanVideoTorchCompileSettings": "WanVideo Torch Compile Settings",
"LoadWanVideoT5TextEncoder": "WanVideo T5 Text Encoder Loader",
"LoadWanVideoClipTextEncoder": "WanVideo CLIP Text Encoder Loader",
}