# ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)
import importlib.metadata
import logging

import torch
from tqdm import tqdm
from accelerate.utils import set_module_tensor_to_device
from packaging.version import parse as parse_version  # already in diffusers' dependency chain

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)


def check_diffusers_version():
    try:
        version = importlib.metadata.version('diffusers')
        required_version = '0.31.0'
        # Compare parsed versions rather than raw strings: lexicographic string
        # comparison would rank '0.9.0' above '0.31.0'.
        if parse_version(version) < parse_version(required_version):
            raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
    except importlib.metadata.PackageNotFoundError:
        raise AssertionError("diffusers is not installed.")


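# Usage sketch: call once at import/registration time so an outdated diffusers
# install fails early with a readable error instead of an AttributeError deep
# inside a node:
#
#   check_diffusers_version()
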
def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    log.info(f"Allocated memory: {memory=:.3f} GB")
    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
    #log.info(f"Memory Summary:\n{memory_summary}")


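# Usage sketch: reset the peak counters first so the "max" figures reflect
# only the following pass (the device index and workload are assumptions):
#
#   device = torch.device("cuda:0")
#   torch.cuda.reset_peak_memory_stats(device)
#   run_inference()  # hypothetical workload
#   print_memory(device)
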
def get_module_memory_mb(module):
    memory = 0
    for param in module.parameters():
        if param.data is not None:
            memory += param.nelement() * param.element_size()
    return memory / (1024 * 1024)  # Convert to MB


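# Worked example (assumed layer, float32): a Linear(1024, 1024) stores
# 1024*1024 weights plus 1024 biases at 4 bytes each, roughly 4.004 MB:
#
#   get_module_memory_mb(torch.nn.Linear(1024, 1024))  # ~4.004
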
def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, dtype=None, base_dtype=None, state_dict=None, low_mem_load=False):
    params_to_keep = params_to_keep or set()  # tolerate the None default on the low_mem_load paths
    to_load = []
    for n, m in model.model.named_modules():
        params = []
        skip = False
        for name, param in m.named_parameters(recurse=False):
            params.append(name)
        for name, param in m.named_parameters(recurse=True):
            if name not in params:
                skip = True # skip random weights in non leaf modules
                break
        if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
            to_load.append((n, m, params))

    to_load.sort(key=lambda x: x[0], reverse=True)  # sort by module name; comparing the modules themselves would raise TypeError
    for name, m, params in tqdm(to_load, desc="Loading model and applying LoRA weights:", leave=True):
        if getattr(m, "comfy_patched_weights", False):
            continue
        for param in params:
            name = name.replace("._orig_mod.", ".") # torch compiled modules have this prefix
            if low_mem_load:
                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
                if "modulation" in name:
                    dtype_to_use = torch.float32
                # strip the wrapper prefix so the key matches the bare diffusion_model state dict
                name_no_prefix = name[len("diffusion_model."):] if name.startswith("diffusion_model.") else name
                key = "{}.{}".format(name_no_prefix, param)
                set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[key])
            model.patch_weight_to_device("{}.{}".format(name, param), device_to=device_to)
            if low_mem_load:
                set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=model.model.diffusion_model.state_dict()[key])

        m.comfy_patched_weights = True

    model.current_weight_patches_uuid = model.patches_uuid
    if low_mem_load:
        # move any parameters the patch loop missed onto the load device
        for name, param in model.model.diffusion_model.named_parameters():
            if param.device != transformer_load_device:
                dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
                if "modulation" in name:
                    dtype_to_use = torch.float32
                set_module_tensor_to_device(model.model.diffusion_model, name, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[name])
    return model


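# Usage sketch, assuming `model` is a ComfyUI ModelPatcher whose LoRA patches
# have already been attached (e.g. via add_patches()); this call then
# materializes the patched weights on the target device:
#
#   patched = apply_lora(model, device_to=torch.device("cuda:0"),
#                        transformer_load_device=torch.device("cuda:0"))
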
# from https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/9d076a3df0d2763cef5510ec5ab807f6632c39f5/utils.py#L181
def split_tiles(embeds, num_split):
    _, H, W, _ = embeds.shape
    out = []
    for x in embeds:
        x = x.unsqueeze(0)
        h, w = H // num_split, W // num_split
        x_split = torch.cat([x[:, i*h:(i+1)*h, j*w:(j+1)*w, :] for i in range(num_split) for j in range(num_split)], dim=0)
        out.append(x_split)

    x_split = torch.stack(out, dim=0)

    return x_split


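# Shape sketch, assuming [B, H, W, C] image batches: with embeds of shape
# [2, 448, 448, 3] and num_split=2, each image yields a 2x2 grid of 224x224
# tiles, so the result is [2, 4, 224, 224, 3]. Trailing rows/columns are
# dropped when H or W is not divisible by num_split.
#
#   tiles = split_tiles(torch.randn(2, 448, 448, 3), 2)
#   tiles.shape  # torch.Size([2, 4, 224, 224, 3])
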
def merge_hiddenstates(x, tiles):
    chunk_size = tiles*tiles
    x = x.split(chunk_size)

    out = []
    for embeds in x:
        num_tiles = embeds.shape[0]
        tile_size = int((embeds.shape[1]-1) ** 0.5)
        grid_size = int(num_tiles ** 0.5)

        # Extract class tokens
        class_tokens = embeds[:, 0, :]  # Save class tokens: [num_tiles, embeds[-1]]
        avg_class_token = class_tokens.mean(dim=0, keepdim=True).unsqueeze(0)  # Average token, shape: [1, 1, embeds[-1]]

        patch_embeds = embeds[:, 1:, :]  # Shape: [num_tiles, tile_size^2, embeds[-1]]
        reshaped = patch_embeds.reshape(grid_size, grid_size, tile_size, tile_size, embeds.shape[-1])

        merged = torch.cat([torch.cat([reshaped[i, j] for j in range(grid_size)], dim=1)
                            for i in range(grid_size)], dim=0)

        merged = merged.unsqueeze(0)  # Shape: [1, grid_size*tile_size, grid_size*tile_size, embeds[-1]]

        # Pool to original size
        pooled = torch.nn.functional.adaptive_avg_pool2d(merged.permute(0, 3, 1, 2), (tile_size, tile_size)).permute(0, 2, 3, 1)
        flattened = pooled.reshape(1, tile_size*tile_size, embeds.shape[-1])

        # Add back the class token
        with_class = torch.cat([avg_class_token, flattened], dim=1)  # Shape: original shape
        out.append(with_class)

    out = torch.cat(out, dim=0)

    return out


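# Worked shapes, assuming CLIP ViT-L/14 at 224px (16x16 patches + 1 class
# token = 257 tokens, width 1024) and tiles=2: each input chunk is
# [4, 257, 1024]; the four 16x16 patch grids are stitched into one 32x32
# grid, average-pooled back down to 16x16, and the averaged class token is
# prepended, giving [1, 257, 1024] per original image.
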
from comfy.clip_vision import clip_preprocess, ClipVisionModel


def clip_encode_image_tiled(clip_vision, image, tiles=1, ratio=1.0):
    embeds = encode_image_(clip_vision, image)
    tiles = min(tiles, 16)

    if tiles > 1:
        # split in tiles
        image_split = split_tiles(image, tiles)

        # get the embeds for each tile
        embeds_split = {}
        for i in image_split:  # i: [tiles*tiles, h, w, C] batch of tiles for one image
            encoded = encode_image_(clip_vision, i)
            # accumulate tile embeddings along the batch dimension
            if "last_hidden_state" not in embeds_split:
                embeds_split["last_hidden_state"] = encoded
            else:
                embeds_split["last_hidden_state"] = torch.cat([embeds_split["last_hidden_state"], encoded], dim=0)

        embeds_split['last_hidden_state'] = merge_hiddenstates(embeds_split['last_hidden_state'], tiles)

        if embeds.shape[0] > 1: # if we have more than one image we need to average the embeddings for consistency
            embeds = embeds * ratio + embeds_split['last_hidden_state']*(1-ratio)
        else: # otherwise we can concatenate them, they can be averaged later
            embeds = torch.cat([embeds * ratio, embeds_split['last_hidden_state']])

    return embeds


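# Usage sketch (argument values are assumptions): blend a whole-image
# embedding with a 2x2 tiled pass, weighting the global view at 0.75:
#
#   embeds = clip_encode_image_tiled(clip_vision, image, tiles=2, ratio=0.75)
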
def encode_image_(clip_vision, image):
    if isinstance(clip_vision, ClipVisionModel):
        out = clip_vision.encode_image(image).last_hidden_state
    else:
        pixel_values = clip_preprocess(image, size=224, crop=True).float()
        out = clip_vision.visual(pixel_values)

    return out
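
# Note: the fallback branch assumes an open_clip-style model exposing
# .visual(); ComfyUI's ClipVisionModel handles its own preprocessing, so
# only the raw-model path needs the explicit clip_preprocess() call.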