1
0
mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

Add ref_mask input

This commit is contained in:
kijai
2025-12-08 09:32:09 +02:00
parent 74cad232fd
commit 2f97b1bd88

View File

@@ -16,6 +16,9 @@ class WanVideoAddOneToAllReferenceEmbeds:
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the embedding application"}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the embedding application"}),
},
"optional": {
"ref_mask": ("MASK",),
}
}
RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
@@ -23,14 +26,20 @@ class WanVideoAddOneToAllReferenceEmbeds:
FUNCTION = "add"
CATEGORY = "WanVideoWrapper"
def add(self, embeds, vae, ref_image, strength, start_percent=0.0, end_percent=1.0):
def add(self, embeds, vae, ref_image, strength, start_percent, end_percent, ref_mask=None):
updated = dict(embeds)
ref_latent = ref_latent_empty = None
vae.to(device)
ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
ref_latent = vae.encode([ref_image_in], device, tiled=False)
ref_latent_empty = vae.encode([torch.zeros_like(ref_image_in)-1], device, tiled=False)
ref_mask_in = None
if ref_mask is not None:
print("ref_image_in shape:", ref_image_in.shape)
ref_mask_in = (ref_mask.unsqueeze(0).repeat(3, 1, 1, 1) * 2 - 1.).to(device, vae.dtype)
else:
ref_mask_in = torch.zeros_like(ref_image_in)-1
ref_latent_empty = vae.encode([ref_mask_in], device, tiled=False)
vae.to(offload_device)
updated.setdefault("one_to_all_embeds", {})