# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""

import numpy as np
import torch

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
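
# Illustrative usage sketch (added for exposition; not part of the original
# file): const_like() returns a cached constant tensor matching the dtype and
# device of a reference tensor, so small fixed kernels (e.g. the resampling
# filter below) are not re-allocated on every call.
def _example_const_like():
    ref = torch.randn(2, 3, dtype=torch.float16)
    f = const_like(ref, [1, 1])  # float16 tensor on ref's device
    return f.dtype, f.device
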
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

class Normalize(torch.nn.Module):
    def __init__(self, dim=None, eps=1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return normalize(x, dim=self.dim, eps=self.eps)
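
# Illustrative check (added for exposition; not part of the original file):
# the alpha = sqrt(norm.numel() / x.numel()) factor divides each L2 norm by
# sqrt(N), where N is the number of elements per normalized slice, so
# normalize() targets unit RMS per element rather than unit L2 norm.
def _example_normalize():
    x = torch.randn(8, 512) * 7.0  # arbitrary input magnitude
    y = normalize(x)
    return y.square().mean(dim=1)  # ~1.0 for every row
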
#----------------------------------------------------------------------------
# Upsample or downsample the given tensor with the given filter,
# or keep it as is.

def resample(x, f=[1, 1], mode='keep'):
    if mode == 'keep':
        return x
    f = np.float32(f)
    assert f.ndim == 1 and len(f) % 2 == 0
    pad = (len(f) - 1) // 2
    f = f / f.sum()
    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
    f = const_like(x, f)
    c = x.shape[1]
    if mode == 'down':
        return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    assert mode == 'up'
    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
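
# Illustrative shape check (added for exposition; not part of the original
# file): with the default box filter f = [1, 1], 'down' halves and 'up'
# doubles the spatial resolution; the (f * 4) factor in the 'up' branch
# compensates for the stride-2 transposed convolution spreading each input
# pixel over a 2x2 block of outputs.
def _example_resample():
    x = torch.randn(1, 3, 16, 16)
    print(resample(x, mode='down').shape)  # torch.Size([1, 3, 8, 8])
    print(resample(x, mode='up').shape)    # torch.Size([1, 3, 32, 32])
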
#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).

def mp_silu(x):
    return torch.nn.functional.silu(x) / 0.596

class MPSiLU(torch.nn.Module):
    def forward(self, x):
        return mp_silu(x)
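
# Numeric sanity check (added for exposition; not part of the original file):
# 0.596 is approximately sqrt(E[silu(x)^2]) for x ~ N(0, 1), so the division
# restores a unit second moment after the nonlinearity.
def _example_mp_silu():
    x = torch.randn(1_000_000)
    print(torch.nn.functional.silu(x).square().mean().sqrt())  # ~0.596
    print(mp_silu(x).square().mean().sqrt())                   # ~1.0
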
#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).

def mp_sum(a, b, t=0.5):
    return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
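
# Illustrative check (added for exposition; not part of the original file):
# lerp(a, b, t) of two independent unit-variance tensors has standard
# deviation sqrt((1 - t)^2 + t^2) < 1; the division restores unit variance.
def _example_mp_sum():
    a, b = torch.randn(2, 1_000_000)
    print(a.lerp(b, 0.3).std())       # ~0.76
    print(mp_sum(a, b, t=0.3).std())  # ~1.0
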
#----------------------------------------------------------------------------
# Magnitude-preserving concatenation (Equation 103).

def mp_cat(a, b, dim=1, t=0.5):
    Na = a.shape[dim]
    Nb = b.shape[dim]
    C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
    wa = C / np.sqrt(Na) * (1 - t)
    wb = C / np.sqrt(Nb) * t
    return torch.cat([wa * a, wb * b], dim=dim)
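
# Illustrative check (added for exposition; not part of the original file):
# the per-branch weights rescale both halves so concatenating two
# unit-variance tensors of unequal width stays at unit variance overall,
# with t controlling the relative contribution of b.
def _example_mp_cat():
    a = torch.randn(4, 64, 1000)
    b = torch.randn(4, 32, 1000)
    out = mp_cat(a, b, dim=1)
    print(out.shape)  # torch.Size([4, 96, 1000])
    print(out.std())  # ~1.0
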
#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with force weight normalization (Equation 66).

class MPConv1D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
        self.weight_norm_removed = False

    def forward(self, x, gain=1):
        assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
        w = self.weight * gain
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 3
        return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2,))

    def remove_weight_norm(self):
        w = self.weight.to(torch.float32)
        w = normalize(w) # traditional weight normalization
        w = w / np.sqrt(w[0].numel())
        w = w.to(self.weight.dtype)
        self.weight.data.copy_(w)
        self.weight_norm_removed = True
        return self
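
# Illustrative usage sketch (added for exposition; not part of the original
# file): this inference-oriented variant bakes the weight normalization into
# the stored weights once via remove_weight_norm(), rather than
# re-normalizing on every forward pass, hence the assert in forward().
def _example_mpconv1d():
    conv = MPConv1D(in_channels=8, out_channels=16, kernel_size=3)
    conv.remove_weight_norm()   # normalize weights in place, once
    x = torch.randn(2, 8, 100)  # (batch, channels, time)
    y = conv(x)                 # padding = kernel_size // 2 keeps the length
    return y.shape              # torch.Size([2, 16, 100])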