mirror of
https://github.com/kijai/ComfyUI-WanVideoWrapper.git
synced 2026-01-26 23:41:35 +03:00
Merge branch 'main' into longvie2
@@ -37,15 +37,16 @@ class WanVideoLongCatAvatarExtendEmbeds(io.ComfyNode):
         new_audio_embed = audio_embeds.copy()
 
         audio_features = torch.stack(new_audio_embed["audio_features"])
+        num_audio_features = audio_features.shape[1]
         if audio_features.shape[1] < frames_processed + num_frames:
             deficit = frames_processed + num_frames - audio_features.shape[1]
             if if_not_enough_audio == "pad_with_start":
-                pad = audio_features[:, :1].repeat(1, deficit, 1, 1, 1)
+                pad = audio_features[:, :1].repeat(1, deficit, 1, 1)
                 audio_features = torch.cat([audio_features, pad], dim=1)
             elif if_not_enough_audio == "mirror_from_end":
                 to_add = audio_features[:, -deficit:, :].flip(dims=[1])
                 audio_features = torch.cat([audio_features, to_add], dim=1)
-            log.info(f"Not enough audio features, extended from {new_audio_embed['audio_features'].shape[1]} to {audio_features.shape[1]} frames.")
+            log.warning(f"Not enough audio features, padded with strategy '{if_not_enough_audio}' from {num_audio_features} to {audio_features.shape[1]} frames")
 
         ref_target_masks = new_audio_embed.get("ref_target_masks", None)
         if ref_target_masks is not None:
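For reference, a minimal sketch of the two fallback strategies handled above, on a simplified [batch, frames, dim] tensor (the real audio features carry extra dimensions); this is an illustration, not code from the node:

import torch

audio_features = torch.arange(6, dtype=torch.float32).reshape(1, 6, 1)   # [batch, frames, dim]
deficit = 3

# "pad_with_start": repeat the first frame 'deficit' times and append it
pad = audio_features[:, :1].repeat(1, deficit, 1)
padded_start = torch.cat([audio_features, pad], dim=1)

# "mirror_from_end": append the last 'deficit' frames in reverse order
to_add = audio_features[:, -deficit:].flip(dims=[1])
padded_mirror = torch.cat([audio_features, to_add], dim=1)

print(padded_start[0, :, 0])    # tensor([0., 1., 2., 3., 4., 5., 0., 0., 0.])
print(padded_mirror[0, :, 0])   # tensor([0., 1., 2., 3., 4., 5., 5., 4., 3.])
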
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -1,780 +0,0 @@
|
||||
{
|
||||
"id": "8b7a9a57-2303-4ef5-9fc2-bf41713bd1fc",
|
||||
"revision": 0,
|
||||
"last_node_id": 46,
|
||||
"last_link_id": 58,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 33,
|
||||
"type": "Note",
|
||||
"pos": [
|
||||
227.3764190673828,
|
||||
-205.28524780273438
|
||||
],
|
||||
"size": [
|
||||
351.70458984375,
|
||||
88
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"properties": {},
|
||||
"widgets_values": [
|
||||
"Models:\nhttps://huggingface.co/Kijai/WanVideo_comfy/tree/main"
|
||||
],
|
||||
"color": "#432",
|
||||
"bgcolor": "#653"
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "LoadWanVideoT5TextEncoder",
|
||||
"pos": [
|
||||
224.15325927734375,
|
||||
-34.481563568115234
|
||||
],
|
||||
"size": [
|
||||
377.1661376953125,
|
||||
130
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "wan_t5_model",
|
||||
"type": "WANTEXTENCODER",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
15
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LoadWanVideoT5TextEncoder"
|
||||
},
|
||||
"widgets_values": [
|
||||
"umt5-xxl-enc-bf16.safetensors",
|
||||
"bf16",
|
||||
"offload_device",
|
||||
"disabled"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
"type": "WanVideoDecode",
|
||||
"pos": [
|
||||
1692.973876953125,
|
||||
-404.8614501953125
|
||||
],
|
||||
"size": [
|
||||
315,
|
||||
174
|
||||
],
|
||||
"flags": {},
|
||||
"order": 12,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "vae",
|
||||
"type": "WANVAE",
|
||||
"link": 43
|
||||
},
|
||||
{
|
||||
"name": "samples",
|
||||
"type": "LATENT",
|
||||
"link": 33
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
48
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoDecode"
|
||||
},
|
||||
"widgets_values": [
|
||||
true,
|
||||
272,
|
||||
272,
|
||||
144,
|
||||
128
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 38,
|
||||
"type": "WanVideoVAELoader",
|
||||
"pos": [
|
||||
1687.4093017578125,
|
||||
-582.2750854492188
|
||||
],
|
||||
"size": [
|
||||
416.25482177734375,
|
||||
82
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "vae",
|
||||
"type": "WANVAE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
43
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoVAELoader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"wanvideo\\Wan2_1_VAE_bf16.safetensors",
|
||||
"bf16"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 42,
|
||||
"type": "GetImageSizeAndCount",
|
||||
"pos": [
|
||||
1708.7301025390625,
|
||||
-140.99705505371094
|
||||
],
|
||||
"size": [
|
||||
277.20001220703125,
|
||||
86
|
||||
],
|
||||
"flags": {},
|
||||
"order": 13,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 48
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
56
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "832 width",
|
||||
"name": "width",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"label": "480 height",
|
||||
"name": "height",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"label": "257 count",
|
||||
"name": "count",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GetImageSizeAndCount"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 16,
|
||||
"type": "WanVideoTextEncode",
|
||||
"pos": [
|
||||
675.8850708007812,
|
||||
-36.032100677490234
|
||||
],
|
||||
"size": [
|
||||
420.30511474609375,
|
||||
261.5306701660156
|
||||
],
|
||||
"flags": {},
|
||||
"order": 10,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "t5",
|
||||
"type": "WANTEXTENCODER",
|
||||
"link": 15
|
||||
},
|
||||
{
|
||||
"name": "model_to_offload",
|
||||
"shape": 7,
|
||||
"type": "WANVIDEOMODEL",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "text_embeds",
|
||||
"type": "WANVIDEOTEXTEMBEDS",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
30
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoTextEncode"
|
||||
},
|
||||
"widgets_values": [
|
||||
"high quality nature video featuring a red panda balancing on a bamboo stem while a bird lands on it's head, on the background there is a waterfall",
|
||||
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
true
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"type": "VHS_VideoCombine",
|
||||
"pos": [
|
||||
2127.120849609375,
|
||||
-511.9014587402344
|
||||
],
|
||||
"size": [
|
||||
873.2135620117188,
|
||||
840.2385864257812
|
||||
],
|
||||
"flags": {},
|
||||
"order": 14,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": 56
|
||||
},
|
||||
{
|
||||
"name": "audio",
|
||||
"shape": 7,
|
||||
"type": "AUDIO",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "meta_batch",
|
||||
"shape": 7,
|
||||
"type": "VHS_BatchManager",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "vae",
|
||||
"shape": 7,
|
||||
"type": "VAE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "Filenames",
|
||||
"type": "VHS_FILENAMES",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VHS_VideoCombine"
|
||||
},
|
||||
"widgets_values": {
|
||||
"frame_rate": 16,
|
||||
"loop_count": 0,
|
||||
"filename_prefix": "WanVideo2_1_T2V",
|
||||
"format": "video/h264-mp4",
|
||||
"pix_fmt": "yuv420p",
|
||||
"crf": 19,
|
||||
"save_metadata": true,
|
||||
"trim_to_audio": false,
|
||||
"pingpong": false,
|
||||
"save_output": true,
|
||||
"videopreview": {
|
||||
"hidden": false,
|
||||
"paused": false,
|
||||
"params": {
|
||||
"filename": "WanVideo2_1_T2V_00412.mp4",
|
||||
"subfolder": "",
|
||||
"type": "output",
|
||||
"format": "video/h264-mp4",
|
||||
"frame_rate": 16,
|
||||
"workflow": "WanVideo2_1_T2V_00412.png",
|
||||
"fullpath": "N:\\AI\\ComfyUI\\output\\WanVideo2_1_T2V_00412.mp4"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"type": "WanVideoEmptyEmbeds",
|
||||
"pos": [
|
||||
1305.26708984375,
|
||||
-571.7843627929688
|
||||
],
|
||||
"size": [
|
||||
315,
|
||||
106
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "image_embeds",
|
||||
"type": "WANVIDIMAGE_EMBEDS",
|
||||
"links": [
|
||||
42
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoEmptyEmbeds"
|
||||
},
|
||||
"widgets_values": [
|
||||
832,
|
||||
480,
|
||||
257
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 35,
|
||||
"type": "WanVideoTorchCompileSettings",
|
||||
"pos": [
|
||||
193.47103881835938,
|
||||
-614.6900024414062
|
||||
],
|
||||
"size": [
|
||||
390.5999755859375,
|
||||
178
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "torch_compile_args",
|
||||
"type": "WANCOMPILEARGS",
|
||||
"slot_index": 0,
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoTorchCompileSettings"
|
||||
},
|
||||
"widgets_values": [
|
||||
"inductor",
|
||||
false,
|
||||
"default",
|
||||
false,
|
||||
64,
|
||||
true
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"type": "WanVideoTeaCache",
|
||||
"pos": [
|
||||
931.4036865234375,
|
||||
-792.5159912109375
|
||||
],
|
||||
"size": [
|
||||
315,
|
||||
154
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "cache_args",
|
||||
"type": "CACHEARGS",
|
||||
"links": [
|
||||
58
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoTeaCache"
|
||||
},
|
||||
"widgets_values": [
|
||||
0.10000000000000002,
|
||||
1,
|
||||
-1,
|
||||
"offload_device",
|
||||
true
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 36,
|
||||
"type": "Note",
|
||||
"pos": [
|
||||
796.0189208984375,
|
||||
-521.5020751953125
|
||||
],
|
||||
"size": [
|
||||
298.2554016113281,
|
||||
108.62744140625
|
||||
],
|
||||
"flags": {},
|
||||
"order": 6,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"properties": {},
|
||||
"widgets_values": [
|
||||
"sdpa should work too, haven't tested flaash\n\nfp8_fast seems to cause huge quality degradation"
|
||||
],
|
||||
"color": "#432",
|
||||
"bgcolor": "#653"
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"type": "Note",
|
||||
"pos": [
|
||||
937.9556274414062,
|
||||
-940.750244140625
|
||||
],
|
||||
"size": [
|
||||
297.4364013671875,
|
||||
88
|
||||
],
|
||||
"flags": {},
|
||||
"order": 7,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"properties": {},
|
||||
"widgets_values": [
|
||||
"TeaCache with context windows is VERY experimental and lower values than normal should be used."
|
||||
],
|
||||
"color": "#432",
|
||||
"bgcolor": "#653"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"type": "WanVideoSampler",
|
||||
"pos": [
|
||||
1315.2401123046875,
|
||||
-401.48028564453125
|
||||
],
|
||||
"size": [
|
||||
315,
|
||||
574.1923217773438
|
||||
],
|
||||
"flags": {},
|
||||
"order": 11,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "WANVIDEOMODEL",
|
||||
"link": 29
|
||||
},
|
||||
{
|
||||
"name": "text_embeds",
|
||||
"type": "WANVIDEOTEXTEMBEDS",
|
||||
"link": 30
|
||||
},
|
||||
{
|
||||
"name": "image_embeds",
|
||||
"type": "WANVIDIMAGE_EMBEDS",
|
||||
"link": 42
|
||||
},
|
||||
{
|
||||
"name": "samples",
|
||||
"shape": 7,
|
||||
"type": "LATENT",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "feta_args",
|
||||
"shape": 7,
|
||||
"type": "FETAARGS",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "context_options",
|
||||
"shape": 7,
|
||||
"type": "WANVIDCONTEXT",
|
||||
"link": 57
|
||||
},
|
||||
{
|
||||
"name": "cache_args",
|
||||
"shape": 7,
|
||||
"type": "CACHEARGS",
|
||||
"link": 58
|
||||
},
|
||||
{
|
||||
"name": "flowedit_args",
|
||||
"shape": 7,
|
||||
"type": "FLOWEDITARGS",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "slg_args",
|
||||
"shape": 7,
|
||||
"type": "SLGARGS",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "loop_args",
|
||||
"shape": 7,
|
||||
"type": "LOOPARGS",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "samples",
|
||||
"type": "LATENT",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
33
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoSampler"
|
||||
},
|
||||
"widgets_values": [
|
||||
30,
|
||||
6,
|
||||
5,
|
||||
1057359483639288,
|
||||
"fixed",
|
||||
true,
|
||||
"unipc",
|
||||
0,
|
||||
1,
|
||||
"",
|
||||
"comfy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 43,
|
||||
"type": "WanVideoContextOptions",
|
||||
"pos": [
|
||||
1307.9542236328125,
|
||||
-855.8865356445312
|
||||
],
|
||||
"size": [
|
||||
315,
|
||||
226
|
||||
],
|
||||
"flags": {},
|
||||
"order": 8,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "vae",
|
||||
"shape": 7,
|
||||
"type": "WANVAE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "context_options",
|
||||
"type": "WANVIDCONTEXT",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
57
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoContextOptions"
|
||||
},
|
||||
"widgets_values": [
|
||||
"uniform_standard",
|
||||
81,
|
||||
4,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
6,
|
||||
2
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
"type": "WanVideoModelLoader",
|
||||
"pos": [
|
||||
620.3950805664062,
|
||||
-357.8426818847656
|
||||
],
|
||||
"size": [
|
||||
477.4410095214844,
|
||||
226.43276977539062
|
||||
],
|
||||
"flags": {},
|
||||
"order": 9,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "compile_args",
|
||||
"shape": 7,
|
||||
"type": "WANCOMPILEARGS",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "block_swap_args",
|
||||
"shape": 7,
|
||||
"type": "BLOCKSWAPARGS",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "lora",
|
||||
"shape": 7,
|
||||
"type": "WANVIDLORA",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "vram_management_args",
|
||||
"shape": 7,
|
||||
"type": "VRAM_MANAGEMENTARGS",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "WANVIDEOMODEL",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
29
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "WanVideoModelLoader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"WanVideo\\wan2.1_t2v_1.3B_fp16.safetensors",
|
||||
"fp16",
|
||||
"disabled",
|
||||
"offload_device",
|
||||
"sdpa"
|
||||
]
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
[
|
||||
15,
|
||||
11,
|
||||
0,
|
||||
16,
|
||||
0,
|
||||
"WANTEXTENCODER"
|
||||
],
|
||||
[
|
||||
29,
|
||||
22,
|
||||
0,
|
||||
27,
|
||||
0,
|
||||
"WANVIDEOMODEL"
|
||||
],
|
||||
[
|
||||
30,
|
||||
16,
|
||||
0,
|
||||
27,
|
||||
1,
|
||||
"WANVIDEOTEXTEMBEDS"
|
||||
],
|
||||
[
|
||||
33,
|
||||
27,
|
||||
0,
|
||||
28,
|
||||
1,
|
||||
"LATENT"
|
||||
],
|
||||
[
|
||||
42,
|
||||
37,
|
||||
0,
|
||||
27,
|
||||
2,
|
||||
"WANVIDIMAGE_EMBEDS"
|
||||
],
|
||||
[
|
||||
43,
|
||||
38,
|
||||
0,
|
||||
28,
|
||||
0,
|
||||
"VAE"
|
||||
],
|
||||
[
|
||||
48,
|
||||
28,
|
||||
0,
|
||||
42,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
56,
|
||||
42,
|
||||
0,
|
||||
30,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
57,
|
||||
43,
|
||||
0,
|
||||
27,
|
||||
5,
|
||||
"WANVIDCONTEXT"
|
||||
],
|
||||
[
|
||||
58,
|
||||
45,
|
||||
0,
|
||||
27,
|
||||
6,
|
||||
"TEACACHEARGS"
|
||||
]
|
||||
],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 0.8140274938684471,
|
||||
"offset": [
|
||||
-122.25834160503663,
|
||||
993.5739491626379
|
||||
]
|
||||
},
|
||||
"node_versions": {
|
||||
"ComfyUI-WanVideoWrapper": "5a2383621a05825d0d0437781afcb8552d9590fd",
|
||||
"ComfyUI-KJNodes": "a5bd3c86c8ed6b83c55c2d0e7a59515b15a0137f",
|
||||
"ComfyUI-VideoHelperSuite": "0a75c7958fe320efcb052f1d9f8451fd20c730a8"
|
||||
},
|
||||
"VHS_latentpreview": true,
|
||||
"VHS_latentpreviewrate": 0,
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
multitalk/multitalk_loop.py (new file, 483 lines)
@@ -0,0 +1,483 @@
|
||||
import torch
|
||||
import os
|
||||
import gc
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from ..latent_preview import prepare_callback
|
||||
from ..wanvideo.schedulers import get_scheduler
|
||||
from .multitalk import timestep_transform, add_noise
|
||||
from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap
|
||||
from comfy.utils import load_torch_file
|
||||
from ..nodes_model_loading import load_weights
|
||||
from ..HuMo.nodes import get_audio_emb_window
|
||||
import comfy.model_management as mm
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
|
||||
VAE_STRIDE = (4, 8, 8)
|
||||
PATCH_SIZE = (1, 2, 2)
|
||||
vae_upscale_factor = 16
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
|
||||
def multitalk_loop(self, **kwargs):
|
||||
# Unpack kwargs into local variables
|
||||
(latent, total_steps, steps, start_step, end_step, shift, cfg, denoise_strength,
|
||||
sigmas, weight_dtype, transformer, patcher, block_swap_args, model, vae, dtype,
|
||||
scheduler, scheduler_step_args, text_embeds, image_embeds, multitalk_embeds,
|
||||
multitalk_audio_embeds, unianim_data, dwpose_data, unianimate_poses, uni3c_embeds,
|
||||
humo_image_cond, humo_image_cond_neg, humo_audio, humo_reference_count,
|
||||
add_noise_to_samples, audio_stride, use_tsr, tsr_k, tsr_sigma, fantasy_portrait_input,
|
||||
noise, timesteps, force_offload, add_cond, control_latents, audio_proj,
|
||||
control_camera_latents, samples, masks, seed_g, gguf_reader, predict_func
|
||||
) = (kwargs.get(k) for k in (
|
||||
'latent', 'total_steps', 'steps', 'start_step', 'end_step', 'shift', 'cfg',
|
||||
'denoise_strength', 'sigmas', 'weight_dtype', 'transformer', 'patcher',
|
||||
'block_swap_args', 'model', 'vae', 'dtype', 'scheduler', 'scheduler_step_args',
|
||||
'text_embeds', 'image_embeds', 'multitalk_embeds', 'multitalk_audio_embeds',
|
||||
'unianim_data', 'dwpose_data', 'unianimate_poses', 'uni3c_embeds',
|
||||
'humo_image_cond', 'humo_image_cond_neg', 'humo_audio', 'humo_reference_count',
|
||||
'add_noise_to_samples', 'audio_stride', 'use_tsr', 'tsr_k', 'tsr_sigma',
|
||||
'fantasy_portrait_input', 'noise', 'timesteps', 'force_offload', 'add_cond',
|
||||
'control_latents', 'audio_proj', 'control_camera_latents', 'samples', 'masks',
|
||||
'seed_g', 'gguf_reader', 'predict_with_cfg'
|
||||
))
|
||||
|
||||
mode = image_embeds.get("multitalk_mode", "multitalk")
|
||||
if mode == "auto":
|
||||
mode = transformer.multitalk_model_type.lower()
|
||||
log.info(f"Multitalk mode: {mode}")
|
||||
cond_frame = None
|
||||
offload = image_embeds.get("force_offload", False)
|
||||
offloaded = False
|
||||
tiled_vae = image_embeds.get("tiled_vae", False)
|
||||
frame_num = clip_length = image_embeds.get("frame_window_size", 81)
|
||||
|
||||
clip_embeds = image_embeds.get("clip_context", None)
|
||||
if clip_embeds is not None:
|
||||
clip_embeds = clip_embeds.to(dtype)
|
||||
colormatch = image_embeds.get("colormatch", "disabled")
|
||||
motion_frame = image_embeds.get("motion_frame", 25)
|
||||
target_w = image_embeds.get("target_w", None)
|
||||
target_h = image_embeds.get("target_h", None)
|
||||
original_images = cond_image = image_embeds.get("multitalk_start_image", None)
|
||||
if original_images is None:
|
||||
original_images = torch.zeros([noise.shape[0], 1, target_h, target_w], device=device)
|
||||
|
||||
output_path = image_embeds.get("output_path", "")
|
||||
img_counter = 0
|
||||
|
||||
if len(multitalk_embeds['audio_features'])==2 and (multitalk_embeds['ref_target_masks'] is None):
|
||||
face_scale = 0.1
|
||||
x_min, x_max = int(target_h * face_scale), int(target_h * (1 - face_scale))
|
||||
lefty_min, lefty_max = int((target_w//2) * face_scale), int((target_w//2) * (1 - face_scale))
|
||||
righty_min, righty_max = int((target_w//2) * face_scale + (target_w//2)), int((target_w//2) * (1 - face_scale) + (target_w//2))
|
||||
human_mask1, human_mask2 = (torch.zeros([target_h, target_w]) for _ in range(2))
|
||||
human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
|
||||
human_mask2[x_min:x_max, righty_min:righty_max] = 1
|
||||
background_mask = torch.where((human_mask1 + human_mask2) > 0, torch.tensor(0), torch.tensor(1))
|
||||
human_masks = [human_mask1, human_mask2, background_mask]
|
||||
ref_target_masks = torch.stack(human_masks, dim=0)
|
||||
multitalk_embeds['ref_target_masks'] = ref_target_masks
|
||||
|
||||
gen_video_list = []
|
||||
is_first_clip = True
|
||||
arrive_last_frame = False
|
||||
cur_motion_frames_num = 1
|
||||
audio_start_idx = iteration_count = step_iteration_count = 0
|
||||
audio_end_idx = (audio_start_idx + clip_length) * audio_stride
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
current_condframe_index = 0
|
||||
|
||||
audio_embedding = multitalk_audio_embeds
|
||||
human_num = len(audio_embedding)
|
||||
audio_embs = None
|
||||
cond_frame = None
|
||||
|
||||
uni3c_data = None
|
||||
if uni3c_embeds is not None:
|
||||
transformer.controlnet = uni3c_embeds["controlnet"]
|
||||
uni3c_data = uni3c_embeds.copy()
|
||||
|
||||
encoded_silence = None
|
||||
|
||||
try:
|
||||
silence_path = os.path.join(script_directory, "encoded_silence.safetensors")
|
||||
encoded_silence = load_torch_file(silence_path)["audio_emb"].to(dtype)
|
||||
except:
|
||||
log.warning("No encoded silence file found, padding with end of audio embedding instead.")
|
||||
|
||||
total_frames = len(audio_embedding[0])
|
||||
estimated_iterations = total_frames // (frame_num - motion_frame) + 1
|
||||
callback = prepare_callback(patcher, estimated_iterations)
|
||||
|
||||
if frame_num >= total_frames:
|
||||
arrive_last_frame = True
|
||||
estimated_iterations = 1
|
||||
|
||||
log.info(f"Sampling {total_frames} frames in {estimated_iterations} windows, at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps")
|
||||
|
||||
while True: # start video generation iteratively
|
||||
self.cache_state = [None, None]
|
||||
|
||||
cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num-1) // 4)
|
||||
if mode == "infinitetalk":
|
||||
cond_image = original_images[:, :, current_condframe_index:current_condframe_index+1] if cond_image is not None else None
|
||||
if multitalk_embeds is not None:
|
||||
audio_embs = []
|
||||
# split audio with window size
|
||||
for human_idx in range(human_num):
|
||||
center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0]-1)
|
||||
audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device)
|
||||
audio_embs.append(audio_emb)
|
||||
audio_embs = torch.concat(audio_embs, dim=0).to(dtype)
|
||||
|
||||
h, w = (cond_image.shape[-2], cond_image.shape[-1]) if cond_image is not None else (target_h, target_w)
|
||||
lat_h, lat_w = h // VAE_STRIDE[1], w // VAE_STRIDE[2]
|
||||
latent_frame_num = (frame_num - 1) // 4 + 1
|
||||
|
||||
noise = torch.randn(
|
||||
16, latent_frame_num,
|
||||
lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device)
|
||||
|
||||
# Calculate the correct latent slice based on current iteration
|
||||
if is_first_clip:
|
||||
latent_start_idx = 0
|
||||
latent_end_idx = noise.shape[1]
|
||||
else:
|
||||
new_frames_per_iteration = frame_num - motion_frame
|
||||
new_latent_frames_per_iteration = ((new_frames_per_iteration - 1) // 4 + 1)
|
||||
latent_start_idx = iteration_count * new_latent_frames_per_iteration
|
||||
latent_end_idx = latent_start_idx + noise.shape[1]
|
||||
|
||||
if samples is not None:
|
||||
noise_mask = samples.get("noise_mask", None)
|
||||
input_samples = samples["samples"]
|
||||
if input_samples is not None:
|
||||
input_samples = input_samples.squeeze(0).to(noise)
|
||||
# Check if we have enough frames in input_samples
|
||||
if latent_end_idx > input_samples.shape[1]:
|
||||
# We need more frames than available - pad the input_samples at the end
|
||||
pad_length = latent_end_idx - input_samples.shape[1]
|
||||
last_frame = input_samples[:, -1:].repeat(1, pad_length, 1, 1)
|
||||
input_samples = torch.cat([input_samples, last_frame], dim=1)
|
||||
input_samples = input_samples[:, latent_start_idx:latent_end_idx]
|
||||
if noise_mask is not None:
|
||||
original_image = input_samples.to(device)
|
||||
|
||||
assert input_samples.shape[1] == noise.shape[1], f"Slice mismatch: {input_samples.shape[1]} vs {noise.shape[1]}"
|
||||
|
||||
if add_noise_to_samples:
|
||||
latent_timestep = timesteps[0]
|
||||
noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
|
||||
else:
|
||||
noise = input_samples
|
||||
|
||||
# diff diff prep
|
||||
if noise_mask is not None:
|
||||
if len(noise_mask.shape) == 4:
|
||||
noise_mask = noise_mask.squeeze(1)
|
||||
if audio_end_idx > noise_mask.shape[0]:
|
||||
noise_mask = noise_mask.repeat(audio_end_idx // noise_mask.shape[0], 1, 1)
|
||||
noise_mask = noise_mask[audio_start_idx:audio_end_idx]
|
||||
noise_mask = torch.nn.functional.interpolate(
|
||||
noise_mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W]
|
||||
size=(noise.shape[1], noise.shape[2], noise.shape[3]),
|
||||
mode='trilinear',
|
||||
align_corners=False
|
||||
).repeat(1, noise.shape[0], 1, 1, 1)
|
||||
|
||||
thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps)
|
||||
thresholds = thresholds.reshape(-1, 1, 1, 1, 1).to(device)
|
||||
masks = (1-noise_mask.repeat(len(timesteps), 1, 1, 1, 1).to(device)) > thresholds
|
||||
|
||||
# zero padding and vae encode for img cond
|
||||
if cond_image is not None or cond_frame is not None:
|
||||
cond_ = cond_image if (is_first_clip or humo_image_cond is None) else cond_frame
|
||||
cond_frame_num = cond_.shape[2]
|
||||
video_frames = torch.zeros(1, 3, frame_num-cond_frame_num, target_h, target_w, device=device, dtype=vae.dtype)
|
||||
padding_frames_pixels_values = torch.concat([cond_.to(device, vae.dtype), video_frames], dim=2)
|
||||
|
||||
# encode
|
||||
vae.to(device)
|
||||
y = vae.encode(padding_frames_pixels_values, device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]
|
||||
|
||||
if mode == "multitalk":
|
||||
latent_motion_frames = y[:, :cur_motion_frames_latent_num] # C T H W
|
||||
else:
|
||||
cond_ = cond_image if is_first_clip else cond_frame
|
||||
latent_motion_frames = vae.encode(cond_.to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]
|
||||
|
||||
vae.to(offload_device)
|
||||
|
||||
#motion_frame_index = cur_motion_frames_latent_num if mode == "infinitetalk" else 1
|
||||
msk = torch.zeros(4, latent_frame_num, lat_h, lat_w, device=device, dtype=dtype)
|
||||
msk[:, :1] = 1
|
||||
y = torch.cat([msk, y]) # 4+C T H W
|
||||
mm.soft_empty_cache()
|
||||
else:
|
||||
y = None
|
||||
latent_motion_frames = noise[:, :1]
|
||||
|
||||
partial_humo_cond_input = partial_humo_cond_neg_input = partial_humo_audio = partial_humo_audio_neg = None
|
||||
if humo_image_cond is not None:
|
||||
partial_humo_cond_input = humo_image_cond[:, :latent_frame_num]
|
||||
partial_humo_cond_neg_input = humo_image_cond_neg[:, :latent_frame_num]
|
||||
if y is not None:
|
||||
partial_humo_cond_input[:, :1] = y[:, :1]
|
||||
if humo_reference_count > 0:
|
||||
partial_humo_cond_input[:, -humo_reference_count:] = humo_image_cond[:, -humo_reference_count:]
|
||||
partial_humo_cond_neg_input[:, -humo_reference_count:] = humo_image_cond_neg[:, -humo_reference_count:]
|
||||
|
||||
if humo_audio is not None:
|
||||
if is_first_clip:
|
||||
audio_embs = None
|
||||
|
||||
partial_humo_audio, _ = get_audio_emb_window(humo_audio, frame_num, frame0_idx=audio_start_idx)
|
||||
#zero_audio_pad = torch.zeros(humo_reference_count, *partial_humo_audio.shape[1:], device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)
|
||||
partial_humo_audio[-humo_reference_count:] = 0
|
||||
partial_humo_audio_neg = torch.zeros_like(partial_humo_audio, device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)
|
||||
|
||||
if scheduler == "multitalk":
|
||||
timesteps = list(np.linspace(1000, 1, steps, dtype=np.float32))
|
||||
timesteps.append(0.)
|
||||
timesteps = [torch.tensor([t], device=device) for t in timesteps]
|
||||
timesteps = [timestep_transform(t, shift=shift, num_timesteps=1000) for t in timesteps]
|
||||
else:
|
||||
if isinstance(scheduler, dict):
|
||||
sample_scheduler = copy.deepcopy(scheduler["sample_scheduler"])
|
||||
timesteps = scheduler["timesteps"]
|
||||
else:
|
||||
sample_scheduler, timesteps,_,_ = get_scheduler(scheduler, total_steps, start_step, end_step, shift, device, transformer.dim, denoise_strength, sigmas=sigmas)
|
||||
timesteps = [torch.tensor([float(t)], device=device) for t in timesteps] + [torch.tensor([0.], device=device)]
|
||||
|
||||
# sample videos
|
||||
latent = noise
|
||||
|
||||
# injecting motion frames
|
||||
if not is_first_clip and mode == "multitalk":
|
||||
latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
|
||||
motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
|
||||
add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[0])
|
||||
latent[:, :add_latent.shape[1]] = add_latent
|
||||
|
||||
if offloaded:
|
||||
# Load weights
|
||||
if transformer.patched_linear and gguf_reader is None:
|
||||
load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
|
||||
elif gguf_reader is not None: #handle GGUF
|
||||
load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
|
||||
#blockswap init
|
||||
init_blockswap(transformer, block_swap_args, model)
|
||||
|
||||
# Use the appropriate prompt for this section
|
||||
if len(text_embeds["prompt_embeds"]) > 1:
|
||||
prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
|
||||
positive = [text_embeds["prompt_embeds"][prompt_index]]
|
||||
log.info(f"Using prompt index: {prompt_index}")
|
||||
else:
|
||||
positive = text_embeds["prompt_embeds"]
|
||||
|
||||
# uni3c slices
|
||||
if uni3c_embeds is not None:
|
||||
vae.to(device)
|
||||
# Pad original_images if needed
|
||||
num_frames = original_images.shape[2]
|
||||
if audio_end_idx > num_frames:
|
||||
pad_len = audio_end_idx - num_frames
|
||||
last_frame = original_images[:, :, -1:].repeat(1, 1, pad_len, 1, 1)
|
||||
padded_images = torch.cat([original_images, last_frame], dim=2)
|
||||
else:
|
||||
padded_images = original_images
|
||||
render_latent = vae.encode(
|
||||
padded_images[:, :, audio_start_idx:audio_end_idx].to(device, vae.dtype),
|
||||
device=device, tiled=tiled_vae
|
||||
).to(dtype)
|
||||
|
||||
vae.to(offload_device)
|
||||
uni3c_data['render_latent'] = render_latent
|
||||
|
||||
# unianimate slices
|
||||
partial_unianim_data = None
|
||||
if unianim_data is not None:
|
||||
partial_dwpose = dwpose_data[:, :, latent_start_idx:latent_end_idx]
|
||||
partial_unianim_data = {
|
||||
"dwpose": partial_dwpose,
|
||||
"random_ref": unianim_data["random_ref"],
|
||||
"strength": unianimate_poses["strength"],
|
||||
"start_percent": unianimate_poses["start_percent"],
|
||||
"end_percent": unianimate_poses["end_percent"]
|
||||
}
|
||||
|
||||
# fantasy portrait slices
|
||||
partial_fantasy_portrait_input = None
|
||||
if fantasy_portrait_input is not None:
|
||||
adapter_proj = fantasy_portrait_input["adapter_proj"]
|
||||
if latent_end_idx > adapter_proj.shape[1]:
|
||||
pad_len = latent_end_idx - adapter_proj.shape[1]
|
||||
last_frame = adapter_proj[:, -1:, :, :].repeat(1, pad_len, 1, 1)
|
||||
padded_proj = torch.cat([adapter_proj, last_frame], dim=1)
|
||||
else:
|
||||
padded_proj = adapter_proj
|
||||
partial_fantasy_portrait_input = fantasy_portrait_input.copy()
|
||||
partial_fantasy_portrait_input["adapter_proj"] = padded_proj[:, latent_start_idx:latent_end_idx]
|
||||
|
||||
mm.soft_empty_cache()
|
||||
gc.collect()
|
||||
# sampling loop
|
||||
sampling_pbar = tqdm(total=len(timesteps)-1, desc=f"Sampling audio indices {audio_start_idx}-{audio_end_idx}", position=0, leave=True)
|
||||
for i in range(len(timesteps)-1):
|
||||
timestep = timesteps[i]
|
||||
latent_model_input = latent.to(device)
|
||||
if mode == "infinitetalk":
|
||||
if humo_image_cond is None or not is_first_clip:
|
||||
latent_model_input[:, :cur_motion_frames_latent_num] = latent_motion_frames
|
||||
|
||||
noise_pred, _, self.cache_state = predict_func(
|
||||
latent_model_input, cfg[min(i, len(timesteps)-1)], positive, text_embeds["negative_prompt_embeds"],
|
||||
timestep, i, y, clip_embeds, control_latents, None, partial_unianim_data, audio_proj, control_camera_latents, add_cond,
|
||||
cache_state=self.cache_state, multitalk_audio_embeds=audio_embs, fantasy_portrait_input=partial_fantasy_portrait_input,
|
||||
humo_image_cond=partial_humo_cond_input, humo_image_cond_neg=partial_humo_cond_neg_input, humo_audio=partial_humo_audio, humo_audio_neg=partial_humo_audio_neg,
|
||||
uni3c_data = uni3c_data)
|
||||
|
||||
if callback is not None:
|
||||
callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * timestep.to(device) / 1000).detach().permute(1,0,2,3)
|
||||
callback(step_iteration_count, callback_latent, None, estimated_iterations*(len(timesteps)-1))
|
||||
del callback_latent
|
||||
|
||||
sampling_pbar.update(1)
|
||||
step_iteration_count += 1
|
||||
|
||||
# update latent
|
||||
if use_tsr:
|
||||
noise_pred = temporal_score_rescaling(noise_pred, latent, timestep, tsr_k, tsr_sigma)
|
||||
if scheduler == "multitalk":
|
||||
noise_pred = -noise_pred
|
||||
dt = (timesteps[i] - timesteps[i + 1]) / 1000
|
||||
latent = latent + noise_pred * dt[:, None, None, None]
|
||||
else:
|
||||
latent = sample_scheduler.step(noise_pred.unsqueeze(0), timestep, latent.unsqueeze(0).to(noise_pred.device), **scheduler_step_args)[0].squeeze(0)
|
||||
del noise_pred, latent_model_input, timestep
|
||||
|
||||
# differential diffusion inpaint
|
||||
if masks is not None:
|
||||
if i < len(timesteps) - 1:
|
||||
image_latent = add_noise(original_image.to(device), noise.to(device), timesteps[i+1])
|
||||
mask = masks[i].to(latent)
|
||||
latent = image_latent * mask + latent * (1-mask)
|
||||
|
||||
# injecting motion frames
|
||||
if not is_first_clip and mode == "multitalk":
|
||||
latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
|
||||
motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
|
||||
add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[i+1])
|
||||
latent[:, :add_latent.shape[1]] = add_latent
|
||||
else:
|
||||
if humo_image_cond is None or not is_first_clip:
|
||||
latent[:, :cur_motion_frames_latent_num] = latent_motion_frames
|
||||
|
||||
del noise, latent_motion_frames
|
||||
if offload:
|
||||
offload_transformer(transformer, remove_lora=False)
|
||||
offloaded = True
|
||||
if humo_image_cond is not None and humo_reference_count > 0:
|
||||
latent = latent[:,:-humo_reference_count]
|
||||
vae.to(device)
|
||||
videos = vae.decode(latent.unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu()
|
||||
|
||||
vae.to(offload_device)
|
||||
|
||||
sampling_pbar.close()
|
||||
|
||||
# optional color correction (less relevant for InfiniteTalk)
|
||||
if colormatch != "disabled":
|
||||
videos = videos.permute(1, 2, 3, 0).float().numpy()
|
||||
from color_matcher import ColorMatcher
|
||||
cm = ColorMatcher()
|
||||
cm_result_list = []
|
||||
for img in videos:
|
||||
if mode == "multitalk":
|
||||
cm_result = cm.transfer(src=img, ref=original_images[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
|
||||
else:
|
||||
cm_result = cm.transfer(src=img, ref=cond_image[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
|
||||
cm_result_list.append(torch.from_numpy(cm_result).to(vae.dtype))
|
||||
|
||||
videos = torch.stack(cm_result_list, dim=0).permute(3, 0, 1, 2)
|
||||
|
||||
# optionally save generated samples to disk
|
||||
if output_path:
|
||||
video_np = videos.clamp(-1.0, 1.0).add(1.0).div(2.0).mul(255).cpu().float().numpy().transpose(1, 2, 3, 0).astype('uint8')
|
||||
num_frames_to_save = video_np.shape[0] if is_first_clip else video_np.shape[0] - cur_motion_frames_num
|
||||
log.info(f"Saving {num_frames_to_save} generated frames to {output_path}")
|
||||
start_idx = 0 if is_first_clip else cur_motion_frames_num
|
||||
for i in range(start_idx, video_np.shape[0]):
|
||||
im = Image.fromarray(video_np[i])
|
||||
im.save(os.path.join(output_path, f"frame_{img_counter:05d}.png"))
|
||||
img_counter += 1
|
||||
else:
|
||||
gen_video_list.append(videos if is_first_clip else videos[:, cur_motion_frames_num:])
|
||||
|
||||
current_condframe_index += 1
|
||||
iteration_count += 1
|
||||
|
||||
# decide whether is done
|
||||
if arrive_last_frame:
|
||||
break
|
||||
|
||||
# update next condition frames
|
||||
is_first_clip = False
|
||||
cur_motion_frames_num = motion_frame
|
||||
|
||||
cond_ = videos[:, -cur_motion_frames_num:].unsqueeze(0)
|
||||
if mode == "infinitetalk":
|
||||
cond_frame = cond_
|
||||
else:
|
||||
cond_image = cond_
|
||||
|
||||
del videos, latent
|
||||
|
||||
# Repeat audio emb
|
||||
if multitalk_embeds is not None:
|
||||
audio_start_idx += (frame_num - cur_motion_frames_num - humo_reference_count)
|
||||
audio_end_idx = audio_start_idx + clip_length
|
||||
if audio_end_idx >= len(audio_embedding[0]):
|
||||
arrive_last_frame = True
|
||||
miss_lengths = []
|
||||
source_frames = []
|
||||
for human_inx in range(human_num):
|
||||
source_frame = len(audio_embedding[human_inx])
|
||||
source_frames.append(source_frame)
|
||||
if audio_end_idx >= len(audio_embedding[human_inx]):
|
||||
log.warning(f"Audio embedding for subject {human_inx} not long enough: {len(audio_embedding[human_inx])}, need {audio_end_idx}, padding...")
|
||||
miss_length = audio_end_idx - len(audio_embedding[human_inx]) + 3
|
||||
log.warning(f"Padding length: {miss_length}")
|
||||
if encoded_silence is not None:
|
||||
add_audio_emb = encoded_silence[-1*miss_length:]
|
||||
else:
|
||||
add_audio_emb = torch.flip(audio_embedding[human_inx][-1*miss_length:], dims=[0])
|
||||
audio_embedding[human_inx] = torch.cat([audio_embedding[human_inx], add_audio_emb.to(device, dtype)], dim=0)
|
||||
miss_lengths.append(miss_length)
|
||||
else:
|
||||
miss_lengths.append(0)
|
||||
if mode == "infinitetalk" and current_condframe_index >= original_images.shape[2]:
|
||||
last_frame = original_images[:, :, -1:, :, :]
|
||||
miss_length = 1
|
||||
original_images = torch.cat([original_images, last_frame.repeat(1, 1, miss_length, 1, 1)], dim=2)
|
||||
|
||||
if not output_path:
|
||||
gen_video_samples = torch.cat(gen_video_list, dim=1)
|
||||
else:
|
||||
gen_video_samples = torch.zeros(3, 1, 64, 64) # dummy output
|
||||
|
||||
if force_offload:
|
||||
if not model["auto_cpu_offload"]:
|
||||
offload_transformer(transformer)
|
||||
try:
|
||||
print_memory(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
except:
|
||||
pass
|
||||
return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},
nodes.py (111 lines changed)
@@ -888,6 +888,83 @@ class WanVideoAddMTVMotion:
        updated["mtv_crafter_motion"] = new_entry
        return (updated,)

class WanVideoAddStoryMemLatents:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "vae": ("WANVAE",),
            "embeds": ("WANVIDIMAGE_EMBEDS",),
            "memory_images": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, vae, embeds, memory_images):
        updated = dict(embeds)
        story_mem_latents, = WanVideoEncodeLatentBatch().encode(vae, memory_images)
        updated["story_mem_latents"] = story_mem_latents["samples"].squeeze(2).permute(1, 0, 2, 3) # [C, T, H, W]
        return (updated,)


class WanVideoSVIProEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "anchor_samples": ("LATENT", {"tooltip": "Initial start image encoded"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            },
            "optional": {
                "prev_samples": ("LATENT", {"tooltip": "Last latent from previous generation"}),
                "motion_latent_count": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "Number of latents used to continue"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, anchor_samples, num_frames, prev_samples=None, motion_latent_count=1):

        anchor_latent = anchor_samples["samples"][0].clone()

        C, T, H, W = anchor_latent.shape

        total_latents = (num_frames - 1) // 4 + 1
        device = anchor_latent.device
        dtype = anchor_latent.dtype

        if prev_samples is None or motion_latent_count == 0:
            padding_size = total_latents - anchor_latent.shape[1]
            padding = torch.zeros(C, padding_size, H, W, dtype=dtype, device=device)
            y = torch.concat([anchor_latent, padding], dim=1)
        else:
            prev_latent = prev_samples["samples"][0].clone()
            motion_latent = prev_latent[:, -motion_latent_count:]
            padding_size = total_latents - anchor_latent.shape[1] - motion_latent.shape[1]
            padding = torch.zeros(C, padding_size, H, W, dtype=dtype, device=device)
            y = torch.concat([anchor_latent, motion_latent, padding], dim=1)

        msk = torch.ones(1, num_frames, H, W, device=device, dtype=dtype)
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, H, W)
        msk = msk.transpose(1, 2)[0]

        image_embeds = {
            "image_embeds": y,
            "num_frames": num_frames,
            "lat_h": H,
            "lat_w": W,
            "mask": msk
        }

        return (image_embeds,)

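A rough illustration of how WanVideoSVIProEmbeds lays out its conditioning latent: with the default num_frames=81 it allocates (81 - 1) // 4 + 1 = 21 latent frames, fills the leading slots with the anchor latent (and optionally motion latents from the previous chunk) and zero-pads the rest. The channel count and spatial size below (16, 60, 104, i.e. 480x832 at VAE stride 8) are assumptions for the sketch, not values taken from the node:

import torch

C, H, W = 16, 60, 104                        # assumed latent channels and spatial size
num_frames = 81
total_latents = (num_frames - 1) // 4 + 1    # 21 latent frames

anchor_latent = torch.zeros(C, 1, H, W)      # encoded start image
motion_latent = torch.zeros(C, 1, H, W)      # last latent of the previous generation
padding = torch.zeros(C, total_latents - 2, H, W)

y = torch.concat([anchor_latent, motion_latent, padding], dim=1)
print(y.shape)  # torch.Size([16, 21, 60, 104])
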
#region I2V encode
|
||||
class WanVideoImageToVideoEncode:
|
||||
@classmethod
|
||||
@@ -1826,33 +1903,7 @@ class WanVideoContextOptions:
|
||||
}
|
||||
|
||||
return (context_options,)
|
||||
|
||||
|
||||
class WanVideoFlowEdit:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"source_embeds": ("WANVIDEOTEXTEMBEDS", ),
|
||||
"skip_steps": ("INT", {"default": 4, "min": 0}),
|
||||
"drift_steps": ("INT", {"default": 0, "min": 0}),
|
||||
"drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}),
|
||||
"source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
|
||||
"drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {
|
||||
"source_image_embeds": ("WANVIDIMAGE_EMBEDS", ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOWEDITARGS", )
|
||||
RETURN_NAMES = ("flowedit_args",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
DESCRIPTION = "Flowedit options for WanVideo"
|
||||
|
||||
def process(self, **kwargs):
|
||||
return (kwargs,)
|
||||
|
||||
class WanVideoLoopArgs:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -2122,7 +2173,7 @@ class WanVideoEncodeLatentBatch:
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
DESCRIPTION = "Encodes a batch of images individually to create a latent video batch where each video is a single frame, useful for I2V init purposes, for example as multiple context window inits"
|
||||
|
||||
def encode(self, vae, images, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, latent_strength=1.0):
|
||||
def encode(self, vae, images, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128, latent_strength=1.0):
|
||||
vae.to(device)
|
||||
|
||||
images = images.clone()
|
||||
@@ -2228,7 +2279,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoEnhanceAVideo": WanVideoEnhanceAVideo,
|
||||
"WanVideoContextOptions": WanVideoContextOptions,
|
||||
"WanVideoTextEmbedBridge": WanVideoTextEmbedBridge,
|
||||
"WanVideoFlowEdit": WanVideoFlowEdit,
|
||||
"WanVideoControlEmbeds": WanVideoControlEmbeds,
|
||||
"WanVideoSLG": WanVideoSLG,
|
||||
"WanVideoLoopArgs": WanVideoLoopArgs,
|
||||
@@ -2255,6 +2305,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"TextImageEncodeQwenVL": TextImageEncodeQwenVL,
|
||||
"WanVideoUniLumosEmbeds": WanVideoUniLumosEmbeds,
|
||||
"WanVideoAddTTMLatents": WanVideoAddTTMLatents,
|
||||
"WanVideoAddStoryMemLatents": WanVideoAddStoryMemLatents,
|
||||
"WanVideoSVIProEmbeds": WanVideoSVIProEmbeds,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -2270,7 +2322,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video",
|
||||
"WanVideoContextOptions": "WanVideo Context Options",
|
||||
"WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge",
|
||||
"WanVideoFlowEdit": "WanVideo FlowEdit",
|
||||
"WanVideoControlEmbeds": "WanVideo Control Embeds",
|
||||
"WanVideoSLG": "WanVideo SLG",
|
||||
"WanVideoLoopArgs": "WanVideo Loop Args",
|
||||
@@ -2296,4 +2347,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoAddBindweaveEmbeds": "WanVideo Add Bindweave Embeds",
|
||||
"WanVideoUniLumosEmbeds": "WanVideo UniLumos Embeds",
|
||||
"WanVideoAddTTMLatents": "WanVideo Add TTMLatents",
|
||||
"WanVideoAddStoryMemLatents": "WanVideo Add StoryMem Latents",
|
||||
"WanVideoSVIProEmbeds": "WanVideo SVIPro Embeds",
|
||||
}
|
||||
|
||||
@@ -36,6 +36,9 @@ try:
|
||||
except:
|
||||
PromptServer = None
|
||||
|
||||
attention_modes = ["sdpa", "flash_attn_2", "flash_attn_3", "sageattn", "sageattn_3", "radial_sage_attention", "sageattn_compiled",
|
||||
"sageattn_ultravico", "comfy"]
|
||||
|
||||
#from city96's gguf nodes
|
||||
def update_folder_names_and_paths(key, targets=[]):
|
||||
# check for existing key
|
||||
@@ -178,7 +181,7 @@ def standardize_lora_key_format(lora_sd):
|
||||
|
||||
new_key += f".{component}"
|
||||
|
||||
# Handle weight type - this is the critical fix
|
||||
# Handle weight type
|
||||
if weight_type:
|
||||
if weight_type == 'alpha':
|
||||
new_key += '.alpha'
|
||||
@@ -209,12 +212,12 @@ def standardize_lora_key_format(lora_sd):
|
||||
new_key = new_key.replace('time_embedding', 'time.embedding')
|
||||
new_key = new_key.replace('time_projection', 'time.projection')
|
||||
|
||||
# Replace remaining underscores with dots, carefully
|
||||
# Replace remaining underscores with dots
|
||||
parts = new_key.split('.')
|
||||
final_parts = []
|
||||
for part in parts:
|
||||
if part in ['img_emb', 'self_attn', 'cross_attn']:
|
||||
final_parts.append(part) # Keep these intact
|
||||
final_parts.append(part)
|
||||
else:
|
||||
final_parts.append(part.replace('_', '.'))
|
||||
new_key = '.'.join(final_parts)
|
||||
@@ -274,6 +277,20 @@ def standardize_lora_key_format(lora_sd):
|
||||
new_sd[k] = v
|
||||
return new_sd

def compensate_rs_lora_format(lora_sd):
    rank = lora_sd["base_model.model.blocks.0.cross_attn.k.lora_A.weight"].shape[0]
    alpha = torch.tensor(rank * rank // rank ** 0.5)
    log.info(f"Detected rank stabilized peft lora format with rank {rank}, setting alpha to {alpha} to compensate.")
    new_sd = {}
    for k, v in lora_sd.items():
        if k.endswith(".lora_A.weight"):
            new_sd[k] = v
            new_k = k.replace(".lora_A.weight", ".alpha")
            new_sd[new_k] = alpha
        else:
            new_sd[k] = v
    return new_sd
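
The alpha written here is rank^1.5, which under the usual alpha / rank merge scaling reproduces rank-stabilized LoRA's alpha / sqrt(rank) scaling, assuming the LoRA was trained with lora_alpha equal to its rank (that assumption is this example's, not stated by the code). A quick numeric check:

import torch

rank = 128
alpha = torch.tensor(rank * rank // rank ** 0.5)   # 128**2 // sqrt(128) -> tensor(1448.)

print(alpha.item() / rank)     # ~11.31, effective scale under the usual alpha / rank convention
print(rank / rank ** 0.5)      # sqrt(128) ~= 11.31, the rank-stabilized rank / sqrt(rank) scale
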
class WanVideoBlockSwap:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -364,7 +381,7 @@ class WanVideoLoraSelect:
|
||||
"required": {
|
||||
"lora": (folder_paths.get_filename_list("loras"),
|
||||
{"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -1000.0, "max": 1000.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
|
||||
},
|
||||
"optional": {
|
||||
"prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
|
||||
@@ -756,6 +773,8 @@ class WanVideoSetLoRAs:
|
||||
lora_sd = load_torch_file(lora_path, safe_load=True)
|
||||
if "dwpose_embedding.0.weight" in lora_sd: #unianimate
|
||||
raise NotImplementedError("Unianimate LoRA patching is not implemented in this node.")
|
||||
if "base_model.model.blocks.0.cross_attn.k.lora_A.weight" in lora_sd: # assume rs_lora
|
||||
lora_sd = compensate_rs_lora_format(lora_sd)
|
||||
|
||||
lora_sd = standardize_lora_key_format(lora_sd)
|
||||
if l["blocks"]:
|
||||
@@ -967,7 +986,8 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False):
|
||||
from .unianimate.nodes import update_transformer
|
||||
log.info("Unianimate LoRA detected, patching model...")
|
||||
patcher.model.diffusion_model, unianimate_sd = update_transformer(patcher.model.diffusion_model, lora_sd)
|
||||
|
||||
if "base_model.model.blocks.0.cross_attn.k.lora_A.weight" in lora_sd: # assume rs_lora
|
||||
lora_sd = compensate_rs_lora_format(lora_sd)
|
||||
lora_sd = standardize_lora_key_format(lora_sd)
|
||||
|
||||
if l["blocks"]:
|
||||
@@ -989,6 +1009,43 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False):
|
||||
del lora_sd
|
||||
return patcher, control_lora, unianimate_sd
|
||||
|
||||
class WanVideoSetAttentionModeOverride:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL", ),
                "attention_mode": (attention_modes, {"default": "sdpa"}),
                "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Step to start applying the attention mode override"}),
                "end_step": ("INT", {"default": 10000, "min": 1, "max": 10000, "step": 1, "tooltip": "Step to end applying the attention mode override"}),
                "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print verbose info about attention mode override during generation"}),
            },
            "optional": {
                "blocks": ("INT", {"forceInput": True} ),
            }
        }

    RETURN_TYPES = ("WANVIDEOMODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "getmodelpath"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Override the attention mode for the model for specific step and/or block range"

    def getmodelpath(self, model, attention_mode, start_step, end_step, verbose, blocks=None):
        model_clone = model.clone()
        attention_mode_override = {
            "mode": attention_mode,
            "start_step": start_step,
            "end_step": end_step,
            "verbose": verbose,
        }
        if blocks is not None:
            attention_mode_override["blocks"] = blocks
        model_clone.model_options['transformer_options']["attention_mode_override"] = attention_mode_override

        return (model_clone,)
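
The node itself only stores the override dict under model_options["transformer_options"]; the sampler or transformer is expected to consult it during generation. A purely illustrative sketch of what that per-step lookup could look like (this is not the wrapper's actual dispatch code, and treating end_step as inclusive is an assumption):

def resolve_attention_mode(default_mode, override, step):
    # Fall back to the model's normal attention mode outside the configured step range.
    if override is None or not (override["start_step"] <= step <= override["end_step"]):
        return default_mode
    return override["mode"]

override = {"mode": "sageattn", "start_step": 0, "end_step": 20, "verbose": False}
print(resolve_attention_mode("sdpa", override, step=5))    # sageattn
print(resolve_attention_mode("sdpa", override, step=25))   # sdpa
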
#region Model loading
|
||||
class WanVideoModelLoader:
|
||||
@classmethod
|
||||
@@ -1003,17 +1060,7 @@ class WanVideoModelLoader:
|
||||
"load_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
|
||||
},
|
||||
"optional": {
|
||||
"attention_mode": ([
|
||||
"sdpa",
|
||||
"flash_attn_2",
|
||||
"flash_attn_3",
|
||||
"sageattn",
|
||||
"sageattn_3",
|
||||
"radial_sage_attention",
|
||||
"sageattn_compiled",
|
||||
"sageattn_ultravico",
|
||||
"comfy"
|
||||
], {"default": "sdpa"}),
|
||||
"attention_mode": (attention_modes, {"default": "sdpa"}),
|
||||
"compile_args": ("WANCOMPILEARGS", ),
|
||||
"block_swap_args": ("BLOCKSWAPARGS", ),
|
||||
"lora": ("WANVIDLORA", {"default": None}),
|
||||
@@ -1218,9 +1265,7 @@ class WanVideoModelLoader:
|
||||
lynx_ip_layers = "lite"
|
||||
|
||||
model_type = "t2v"
|
||||
if "audio_injector.injector.0.k.weight" in sd:
|
||||
model_type = "s2v"
|
||||
elif not "text_embedding.0.weight" in sd:
|
||||
if not "text_embedding.0.weight" in sd:
|
||||
model_type = "no_cross_attn" #minimaxremover
|
||||
elif "model_type.Wan2_1-FLF2V-14B-720P" in sd or "img_emb.emb_pos" in sd or "flf2v" in model.lower():
|
||||
model_type = "fl2v"
|
||||
@@ -1230,6 +1275,8 @@ class WanVideoModelLoader:
|
||||
model_type = "t2v"
|
||||
elif "control_adapter.conv.weight" in sd:
|
||||
model_type = "t2v"
|
||||
if "audio_injector.injector.0.k.weight" in sd:
|
||||
model_type = "s2v"
|
||||
|
||||
out_dim = 16
|
||||
if dim == 5120: #14B
|
||||
@@ -2044,6 +2091,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoTorchCompileSettings": WanVideoTorchCompileSettings,
|
||||
"LoadWanVideoT5TextEncoder": LoadWanVideoT5TextEncoder,
|
||||
"LoadWanVideoClipTextEncoder": LoadWanVideoClipTextEncoder,
|
||||
"WanVideoSetAttentionModeOverride": WanVideoSetAttentionModeOverride,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -2062,4 +2110,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoTorchCompileSettings": "WanVideo Torch Compile Settings",
|
||||
"LoadWanVideoT5TextEncoder": "WanVideo T5 Text Encoder Loader",
|
||||
"LoadWanVideoClipTextEncoder": "WanVideo CLIP Text Encoder Loader",
|
||||
"WanVideoSetAttentionModeOverride": "WanVideo Set Attention Mode Override",
|
||||
}
|
||||
|
||||
nodes_sampler.py (963 lines changed)
File diff suppressed because it is too large
@@ -38,7 +38,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, current_flag,
 
     qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
 
-    window_th = 1560 * 21 / 2
+    window_th = frame_tokens * window_width / 2
     dist2 = tl.abs(m - n).to(tl.int32)
     dist_mask = dist2 <= window_th
 
@@ -46,7 +46,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, current_flag,
 
     qk = tl.where(dist_mask | negative_mask, qk, qk*multi_factor)
 
-    window3 = (m <= frame_tokens) & (n > 21*frame_tokens)
+    window3 = (m <= frame_tokens) & (n > window_width*frame_tokens)
     qk = tl.where(window3, -1e4, qk)
 
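The old threshold hard-coded 1560 * 21 / 2, which corresponds to 832x480 latents (with VAE stride 8 and 2x2 patches, (480/16)*(832/16) = 1560 tokens per latent frame) and a 21-latent-frame (81-frame) window; the new form derives the same value from frame_tokens and window_width. A worked check under that interpretation (the resolution/window reading is an inference, not stated in the kernel):

height, width, num_frames = 480, 832, 81

frame_tokens = (height // 16) * (width // 16)    # 30 * 52 = 1560 tokens per latent frame
window_width = (num_frames - 1) // 4 + 1         # 21 latent frames

print(frame_tokens * window_width / 2)           # 16380.0
print(1560 * 21 / 2)                             # 16380.0, the previously hard-coded value
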
utils.py (88 lines changed)
@@ -4,17 +4,99 @@ import logging
import math
from tqdm import tqdm
from pathlib import Path
import os
import gc
import types, collections
from comfy.utils import ProgressBar, copy_to_param, set_attr_param
from comfy.model_patcher import get_key_weight, string_to_seed
from comfy.lora import calculate_weight
from comfy.model_management import cast_to_device

from comfy.float import stochastic_rounding
from .custom_linear import remove_lora_from_module
import folder_paths
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

import comfy.model_management as mm
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

try:
from .gguf.gguf import GGUFParameter
except:
pass

class MetaParameter(torch.nn.Parameter):
def __new__(cls, dtype, quant_type=None):
data = torch.empty(0, dtype=dtype)
self = torch.nn.Parameter(data, requires_grad=False)
self.quant_type = quant_type
return self

def offload_transformer(transformer, remove_lora=True):
transformer.teacache_state.clear_all()
transformer.magcache_state.clear_all()
transformer.easycache_state.clear_all()

if transformer.patched_linear:
for name, param in transformer.named_parameters():
if "loras" in name or "controlnet" in name:
continue
module = transformer
subnames = name.split('.')
for subname in subnames[:-1]:
module = getattr(module, subname)
attr_name = subnames[-1]
if param.data.is_floating_point():
meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False)
setattr(module, attr_name, meta_param)
elif isinstance(param.data, GGUFParameter):
quant_type = getattr(param, 'quant_type', None)
setattr(module, attr_name, MetaParameter(param.data.dtype, quant_type))
else:
pass
if remove_lora:
remove_lora_from_module(transformer)
else:
transformer.to(offload_device)

for block in transformer.blocks:
block.kv_cache = None
if transformer.audio_model is not None and hasattr(block, 'audio_block'):
block.audio_block = None

mm.soft_empty_cache()
gc.collect()

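
offload_transformer frees VRAM by swapping the real weights of patched-linear models for empty placeholders on the meta device rather than copying them back to RAM. A self-contained sketch of that trick on a generic nn.Module; the wrapper's GGUF and LoRA handling is omitted and the helper name is made up:

import gc
import torch
import torch.nn as nn

def drop_weights_to_meta(module: nn.Module) -> None:
    # Replace every floating-point parameter with an empty tensor on the
    # 'meta' device: the module keeps its structure (names, shapes, dtypes)
    # but no longer owns real storage, so GPU memory can be reclaimed.
    for name, param in list(module.named_parameters()):
        if not param.data.is_floating_point():
            continue
        owner = module
        *parents, leaf = name.split(".")
        for p in parents:
            owner = getattr(owner, p)
        placeholder = nn.Parameter(torch.empty_like(param.data, device="meta"),
                                   requires_grad=False)
        setattr(owner, leaf, placeholder)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

The trade-off is that the real weights have to be reloaded before the next sampling run, since meta tensors carry no data.
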
def init_blockswap(transformer, block_swap_args, model):
if not transformer.patched_linear:
if block_swap_args is not None:
for name, param in transformer.named_parameters():
if "block" not in name or "control_adapter" in name or "face" in name:
param.data = param.data.to(device)
elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
param.data = param.data.to(offload_device)
elif block_swap_args["offload_img_emb"] and "img_emb" in name:
param.data = param.data.to(offload_device)

transformer.block_swap(
block_swap_args["blocks_to_swap"] - 1 ,
block_swap_args["offload_txt_emb"],
block_swap_args["offload_img_emb"],
vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None),
)
elif model["auto_cpu_offload"]:
for module in transformer.modules():
if hasattr(module, "offload"):
module.offload()
if hasattr(module, "onload"):
module.onload()
for block in transformer.blocks:
block.modulation = torch.nn.Parameter(block.modulation.to(device))
transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device))
else:
transformer.to(device)

def check_device_same(first_device, second_device):
if first_device.type != second_device.type:
return False
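
init_blockswap above expects block_swap_args shaped like the dictionary below. The values are illustrative only; in normal use the dict is produced by the wrapper's block-swap settings node rather than written by hand.

# Illustrative values; the keys mirror exactly what init_blockswap reads.
block_swap_args = {
    "blocks_to_swap": 20,          # how many transformer blocks cycle through CPU
    "offload_txt_emb": True,       # keep text-embedding weights on the offload device
    "offload_img_emb": False,      # keep image-embedding weights on the GPU
    "vace_blocks_to_swap": None,   # optional, only used for VACE blocks
}
# init_blockswap(transformer, block_swap_args, model) then pins the non-block
# weights to the GPU and calls transformer.block_swap(blocks_to_swap - 1, ...).
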
@@ -140,7 +222,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, back
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

if device_to is not None:
temp_weight = cast_to_device(weight, device_to, torch.float32, copy=True)
temp_weight = mm.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:

@@ -80,9 +80,9 @@ except:
try:
from ...ultravico.sageattn.core import sage_attention as sageattn_ultravico
@torch.library.custom_op("wanvideo::sageattn_ultravico", mutates_args=())
def sageattn_func_ultravico(qkv: List[torch.Tensor], attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, multi_factor: float = 0.9
def sageattn_func_ultravico(qkv: List[torch.Tensor], attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, multi_factor: float = 0.9, frame_tokens: int = 1536
) -> torch.Tensor:
return sageattn_ultravico(qkv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, multi_factor=multi_factor)
return sageattn_ultravico(qkv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, multi_factor=multi_factor, frame_tokens=frame_tokens)

@sageattn_func_ultravico.register_fake
def _(qkv, attn_mask=None, dropout_p=0.0, is_causal=False, multi_factor=0.9):
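
Registering the kernel through torch.library.custom_op, with a companion register_fake shape function, is what lets torch.compile trace through it; the new frame_tokens argument has to appear in the op signature before it can be forwarded. A generic sketch of the pattern, with a hypothetical op name and an eager stand-in body, not tied to SageAttention:

from typing import List
import torch

@torch.library.custom_op("example::windowed_attn", mutates_args=())
def windowed_attn(qkv: List[torch.Tensor], multi_factor: float = 0.9,
                  frame_tokens: int = 1536) -> torch.Tensor:
    q, k, v = qkv
    # Stand-in eager implementation; a real op would call the custom kernel here.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@windowed_attn.register_fake
def _(qkv, multi_factor=0.9, frame_tokens=1536):
    # Only metadata matters for tracing: output matches q's shape and dtype.
    return torch.empty_like(qkv[0])
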
@@ -94,7 +94,7 @@ except:

def attention(q, k, v, q_lens=None, k_lens=None, max_seqlen_q=None, max_seqlen_k=None, dropout_p=0.,
softmax_scale=None, q_scale=None, causal=False, window_size=(-1, -1), deterministic=False, dtype=torch.bfloat16,
attention_mode='sdpa', attn_mask=None, multi_factor=0.9, heads=128):
attention_mode='sdpa', attn_mask=None, multi_factor=0.9, frame_tokens=1536, heads=128):
if "flash" in attention_mode:
return flash_attention(q, k, v, q_lens=q_lens, k_lens=k_lens, dropout_p=dropout_p, softmax_scale=softmax_scale,
q_scale=q_scale, causal=causal, window_size=window_size, deterministic=deterministic, dtype=dtype, version=2 if attention_mode == 'flash_attn_2' else 3,
@@ -108,7 +108,7 @@ def attention(q, k, v, q_lens=None, k_lens=None, max_seqlen_q=None, max_seqlen_k
elif attention_mode == 'sageattn':
return sageattn_func(q, k, v, tensor_layout="NHD").contiguous()
elif attention_mode == 'sageattn_ultravico':
return sageattn_func_ultravico([q, k, v], multi_factor=multi_factor).contiguous()
return sageattn_func_ultravico([q, k, v], multi_factor=multi_factor, frame_tokens=frame_tokens).contiguous()
elif attention_mode == 'comfy':
return optimized_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), heads=heads, skip_reshape=True)
else: # sdpa

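
With the extra parameter in place, callers select a backend by name and the wrapper routes to the matching kernel; frame_tokens only matters for the sageattn_ultravico path. A hedged usage sketch, assuming the attention helper defined above is in scope and that tensors use the [batch, seq_len, heads, head_dim] layout the wrapper passes around; the shapes and the frame-size divisor are made up:

import torch

q = torch.randn(1, 2048, 12, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 2048, 12, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 2048, 12, 128, dtype=torch.bfloat16, device="cuda")

out = attention(q, k, v,
                attention_mode="sageattn_ultravico",
                multi_factor=0.9,
                frame_tokens=q.shape[1] // 16)  # tokens per latent frame (assumed)
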
@@ -467,7 +467,7 @@ class WanSelfAttention(nn.Module):
v = (self.v(x) + self.v_loras(x)).view(b, s, n, d)
return q, k, v

def forward(self, q, k, v, seq_lens, lynx_ref_feature=None, lynx_ref_scale=1.0, attention_mode_override=None, onetoall_ref=None, onetoall_ref_scale=1.0):
def forward(self, q, k, v, seq_lens, lynx_ref_feature=None, lynx_ref_scale=1.0, attention_mode_override=None, onetoall_ref=None, onetoall_ref_scale=1.0, frame_tokens=1536):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -482,7 +482,7 @@ class WanSelfAttention(nn.Module):
if self.ref_adapter is not None and lynx_ref_feature is not None:
ref_x = self.ref_adapter(self, q, lynx_ref_feature)

x = attention(q, k, v, k_lens=seq_lens, attention_mode=attention_mode, heads=self.num_heads)
x = attention(q, k, v, k_lens=seq_lens, attention_mode=attention_mode, heads=self.num_heads, frame_tokens=frame_tokens)

if self.ref_adapter is not None and lynx_ref_feature is not None:
x = x.add(ref_x, alpha=lynx_ref_scale)
@@ -497,7 +497,7 @@ class WanSelfAttention(nn.Module):
attention_mode = self.attention_mode
if attention_mode_override is not None:
attention_mode = attention_mode_override

# Concatenate main and IP keys/values for main attention
full_k = torch.cat([k, k_ip], dim=1)
full_v = torch.cat([v, v_ip], dim=1)
@@ -1006,6 +1006,7 @@ class WanAttentionBlock(nn.Module):
longcat_num_cond_latents=0, longcat_avatar_options=None, #longcat image cond amount
x_onetoall_ref=None, onetoall_freqs=None, onetoall_ref=None, onetoall_ref_scale=1.0, #one-to-all
e_tr=None, tr_num=0, tr_start=0, #token replacement
attention_mode_override=None, frame_tokens=None,
):
r"""
Args:
@@ -1150,6 +1151,10 @@ class WanAttentionBlock(nn.Module):
if enhance_enabled:
feta_scores = get_feta_scores(q, k)

if self.attention_mode == "sageattn_3" and attention_mode_override is None:
if current_step != 0 and not last_step:
attention_mode_override = "sageattn"

#self-attention
split_attn = (context is not None
and (context.shape[0] > 1 or (clip_embed is not None and clip_embed.shape[0] > 1))
@@ -1161,19 +1166,14 @@ class WanAttentionBlock(nn.Module):
y = self.self_attn.forward_split(q, k, v, seq_lens, grid_sizes, seq_chunks)
elif ref_target_masks is not None: #multi/infinite talk
y, x_ref_attn_map = self.self_attn.forward_multitalk(q, k, v, seq_lens, grid_sizes, ref_target_masks)
elif self.attention_mode == "radial_sage_attention":
elif self.attention_mode == "radial_sage_attention" or attention_mode_override is not None and attention_mode_override == "radial_sage_attention":
if self.dense_block or self.dense_timesteps is not None and current_step < self.dense_timesteps:
if self.dense_attention_mode == "sparse_sage_attn":
y = self.self_attn.forward_radial(q, k, v, dense_step=True)
else:
y = self.self_attn.forward(q, k, v, seq_lens)
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override=attention_mode_override)
else:
y = self.self_attn.forward_radial(q, k, v, dense_step=False)
elif self.attention_mode == "sageattn_3":
if current_step != 0 and not last_step:
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override="sageattn_3")
else:
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override="sageattn")
elif x_ip is not None and self.kv_cache is None: #stand-in
# First pass: cache IP keys/values and compute attention
self.kv_cache = {"k_ip": k_ip.detach(), "v_ip": v_ip.detach()}
@@ -1184,18 +1184,18 @@ class WanAttentionBlock(nn.Module):
v_ip = self.kv_cache["v_ip"]
full_k = torch.cat([k, k_ip], dim=1)
full_v = torch.cat([v, v_ip], dim=1)
y = self.self_attn.forward(q, full_k, full_v, seq_lens)
y = self.self_attn.forward(q, full_k, full_v, seq_lens, attention_mode_override=attention_mode_override)
elif is_longcat and longcat_num_cond_latents > 0:
if longcat_num_cond_latents == 1:
num_cond_latents_thw = longcat_num_cond_latents * (N // num_latent_frames)
# process the noise tokens
x_noise = self.self_attn.forward(q[:, num_cond_latents_thw:].contiguous(), k, v, seq_lens)
x_noise = self.self_attn.forward(q[:, num_cond_latents_thw:].contiguous(), k, v, seq_lens, attention_mode_override=attention_mode_override)
# process the condition tokens
x_cond = self.self_attn.forward(
q[:, :num_cond_latents_thw].contiguous(),
k[:, :num_cond_latents_thw].contiguous(),
v[:, :num_cond_latents_thw].contiguous(),
seq_lens)
seq_lens, attention_mode_override=attention_mode_override)
# merge x_cond and x_noise
y = torch.cat([x_cond, x_noise], dim=1).contiguous()
elif longcat_num_cond_latents > 1: # video continuation
@@ -1237,13 +1237,14 @@ class WanAttentionBlock(nn.Module):
q_cond = q[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
k_cond = k[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
v_cond = v[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
x_ref = self.self_attn.forward(q_ref, k_ref, v_ref, seq_lens)
x_cond = self.self_attn.forward(q_cond, k_cond, v_cond, seq_lens)
x_ref = self.self_attn.forward(q_ref, k_ref, v_ref, seq_lens, attention_mode_override=attention_mode_override)
x_cond = self.self_attn.forward(q_cond, k_cond, v_cond, seq_lens, attention_mode_override=attention_mode_override)

# merge x_cond and x_noise
y = torch.cat([x_ref, x_cond, x_noise], dim=1).contiguous()
else:
y = self.self_attn.forward(q, k, v, seq_lens, lynx_ref_feature=lynx_ref_feature, lynx_ref_scale=lynx_ref_scale, onetoall_ref=onetoall_ref, onetoall_ref_scale=onetoall_ref_scale)
y = self.self_attn.forward(q, k, v, seq_lens, lynx_ref_feature=lynx_ref_feature, lynx_ref_scale=lynx_ref_scale,
onetoall_ref=onetoall_ref, onetoall_ref_scale=onetoall_ref_scale, attention_mode_override=attention_mode_override, frame_tokens=frame_tokens)

del q, k, v

@@ -2187,7 +2188,7 @@ class WanModel(torch.nn.Module):

def rope_encode_comfy(self, t, h, w, freq_offset=0, t_start=0, ref_frame_shape=None, pose_frame_shape=None,
steps_t=None, steps_h=None, steps_w=None, ntk_alphas=[1,1,1], device=None, dtype=None,
ref_frame_index=10, longcat_num_ref_latents=None):
ref_frame_index=10, longcat_num_ref_latents=0):

patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -2280,6 +2281,7 @@ class WanModel(torch.nn.Module):
self, x, t, context, seq_len,
is_uncond=False,
current_step_percentage=0.0, current_step=0, last_step=0, total_steps=50,
attention_mode_override=None,
clip_fea=None, y=None,
device=torch.device('cuda'),
freqs=None,
@@ -2664,7 +2666,7 @@ class WanModel(torch.nn.Module):
device=x.device,
dtype=x.dtype
)
log.info("Generated new RoPE frequencies")
tqdm.write("Generated new RoPE frequencies")

if s2v_ref_latent is not None:
freqs_ref = self.rope_encode_comfy(
@@ -3093,6 +3095,7 @@ class WanModel(torch.nn.Module):
camera_embed=camera_embed,
audio_proj=audio_proj,
num_latent_frames = F,
frame_tokens=x.shape[1] // F,
original_seq_len=self.original_seq_len,
enhance_enabled=enhance_enabled,
audio_scale=audio_scale,
@@ -3179,8 +3182,21 @@ class WanModel(torch.nn.Module):
if lynx_ref_buffer is None and lynx_ref_feature_extractor:
lynx_ref_buffer = {}

attn_override_blocks = attention_mode = None
attention_mode_override_active = False
if attention_mode_override is not None:
attn_override_blocks = attention_mode_override.get("blocks", range(len(self.blocks)))
if attention_mode_override["start_step"] <= current_step < attention_mode_override["end_step"]:
attention_mode_override_active = True
if attention_mode_override["verbose"]:
tqdm.write(f"Applying attention mode override: {attention_mode_override['mode']} at step {current_step} on blocks: {attn_override_blocks if attn_override_blocks is not None else 'all'}")

for b, block in enumerate(self.blocks):
mm.throw_exception_if_processing_interrupted()
if attention_mode_override_active and b in attn_override_blocks:
attention_mode = attention_mode_override['mode']
else:
attention_mode = None
block_idx = f"{b:02d}"
if lynx_ref_buffer is not None and not lynx_ref_feature_extractor:
lynx_ref_feature = lynx_ref_buffer.get(block_idx, None)
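
The forward pass reads the override as a plain dict and only applies it inside the configured step window and block list. A standalone sketch of that gating with a made-up dict; in practice the dict comes from the WanVideoSetAttentionModeOverride node rather than being hand-written:

# Hypothetical override dict; keys mirror what the loop above reads.
attention_mode_override = {
    "mode": "radial_sage_attention",  # backend to force while active
    "start_step": 2,                  # first sampling step it applies to
    "end_step": 20,                   # exclusive upper bound
    "blocks": range(0, 30),           # transformer blocks it applies to
    "verbose": True,
}

def resolve_block_attention(current_step: int, block_idx: int, override=None):
    # Return the attention mode to force for this block at this step,
    # or None to keep the block's own default.
    if override is None:
        return None
    blocks = override.get("blocks", range(10**6))
    if override["start_step"] <= current_step < override["end_step"] and block_idx in blocks:
        return override["mode"]
    return None
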
@@ -3224,7 +3240,7 @@ class WanModel(torch.nn.Module):
x_onetoall_ref = onetoall_ref_block_samples[b // interval_ref]

# ---run block----#
x, x_ip, lynx_ref_feature, x_ovi = block(x, x_ip=x_ip, lynx_ref_feature=lynx_ref_feature, x_ovi=x_ovi, x_onetoall_ref=x_onetoall_ref, onetoall_freqs=onetoall_freqs, **kwargs)
x, x_ip, lynx_ref_feature, x_ovi = block(x, x_ip=x_ip, lynx_ref_feature=lynx_ref_feature, x_ovi=x_ovi, x_onetoall_ref=x_onetoall_ref, onetoall_freqs=onetoall_freqs, attention_mode_override=attention_mode, **kwargs)
# ---post block----#

# dual controlnet

@@ -42,7 +42,7 @@ def _apply_custom_sigmas(sample_scheduler, sigmas, device):
sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device)
sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps)

def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, enhance_hf=False, **kwargs):
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, denoise_strength=1.0, sigmas=None, log_timesteps=False, enhance_hf=False, **kwargs):
timesteps = None
if sigmas is not None:
steps = len(sigmas) - 1