You've already forked ComfyUI-WanVideoWrapper
mirror of
https://github.com/kijai/ComfyUI-WanVideoWrapper.git
synced 2026-01-28 12:20:55 +03:00
Add v2 sampling nodes for cleaner workflows
No functional changes, just cleaner nodes
This commit is contained in:
155
nodes.py
155
nodes.py
@@ -1,11 +1,8 @@
|
||||
import os, gc, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import hashlib
|
||||
|
||||
from .wanvideo.schedulers import get_scheduler, scheduler_list
|
||||
|
||||
from .utils import(log, clip_encode_image_tiled, add_noise_to_reference_video, set_module_tensor_to_device)
|
||||
from .taehv import TAEHV
|
||||
|
||||
@@ -1927,155 +1924,6 @@ class WanVideoFreeInitArgs:
|
||||
|
||||
def process(self, **kwargs):
|
||||
return (kwargs,)
|
||||
|
||||
class WanVideoScheduler: #WIP
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"scheduler": (scheduler_list, {"default": "unipc"}),
|
||||
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
|
||||
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
|
||||
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
|
||||
},
|
||||
"optional": {
|
||||
"sigmas": ("SIGMAS", ),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
|
||||
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None):
|
||||
sample_scheduler, timesteps, start_idx, end_idx = get_scheduler(
|
||||
scheduler,
|
||||
steps,
|
||||
start_step, end_step, shift,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
log_timesteps=True)
|
||||
|
||||
scheduler_dict = {
|
||||
"sample_scheduler": sample_scheduler,
|
||||
"timesteps": timesteps,
|
||||
}
|
||||
|
||||
try:
|
||||
from server import PromptServer
|
||||
import io
|
||||
import base64
|
||||
import matplotlib.pyplot as plt
|
||||
except:
|
||||
PromptServer = None
|
||||
if unique_id and PromptServer is not None:
|
||||
try:
|
||||
# Plot sigmas and save to a buffer
|
||||
sigmas_np = sample_scheduler.full_sigmas.cpu().numpy()
|
||||
if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
|
||||
sigmas_np = np.append(sigmas_np, 0.0)
|
||||
buf = io.BytesIO()
|
||||
fig = plt.figure(facecolor='#353535')
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_facecolor('#353535') # Set axes background color
|
||||
x_values = range(0, len(sigmas_np))
|
||||
ax.plot(x_values, sigmas_np)
|
||||
# Annotate each sigma value
|
||||
ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3) # Small dots at each sigma
|
||||
for x, y in zip(x_values, sigmas_np):
|
||||
# Show all annotations if few steps, or just show split step annotations
|
||||
show_annotation = len(sigmas_np) <= 10
|
||||
is_split_step = (start_idx > 0 and x == start_idx) or (end_idx != -1 and x == end_idx + 1)
|
||||
|
||||
if show_annotation or is_split_step:
|
||||
color = 'orange'
|
||||
if is_split_step:
|
||||
color = 'yellow'
|
||||
ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color=color, fontsize=12)
|
||||
ax.set_xticks(x_values)
|
||||
ax.set_title("Sigmas", color='white') # Title font color
|
||||
ax.set_xlabel("Step", color='white') # X label font color
|
||||
ax.set_ylabel("Sigma Value", color='white') # Y label font color
|
||||
ax.tick_params(axis='x', colors='white', labelsize=10) # X tick color
|
||||
ax.tick_params(axis='y', colors='white', labelsize=10) # Y tick color
|
||||
# Add split point if end_step is defined
|
||||
end_idx += 1
|
||||
if end_idx != -1 and 0 <= end_idx < len(sigmas_np) - 1:
|
||||
ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split')
|
||||
# Add split point if start_step is defined
|
||||
if start_idx > 0 and 0 <= start_idx < len(sigmas_np):
|
||||
ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
|
||||
if (end_idx != -1 and 0 <= end_idx < len(sigmas_np)) or (start_idx > 0 and 0 <= start_idx < len(sigmas_np)):
|
||||
handles, labels = ax.get_legend_handles_labels()
|
||||
if labels:
|
||||
ax.legend()
|
||||
if start_idx < end_idx and 0 <= start_idx < len(sigmas_np) and 0 < end_idx < len(sigmas_np):
|
||||
ax.axvspan(start_idx, end_idx, color='lightblue', alpha=0.1, label='Sampled Range')
|
||||
plt.tight_layout()
|
||||
plt.savefig(buf, format='png')
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
||||
buf.close()
|
||||
|
||||
# Send as HTML img tag with base64 data
|
||||
html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>"
|
||||
PromptServer.instance.send_progress_text(html_img, unique_id)
|
||||
except Exception as e:
|
||||
print("Failed to send sigmas plot:", e)
|
||||
pass
|
||||
|
||||
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
|
||||
|
||||
class WanVideoSchedulerSA_ODE:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"use_adaptive_order": ("BOOLEAN", {"default": False, "tooltip": "Use adaptive order"}),
|
||||
"use_velocity_smoothing": ("BOOLEAN", {"default": True, "tooltip": "Use velocity smoothing"}),
|
||||
"convergence_threshold": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Convergence threshold for velocity smoothing"}),
|
||||
"smoothing_factor": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Smoothing factor for velocity smoothing"}),
|
||||
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
|
||||
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
|
||||
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
|
||||
},
|
||||
"optional": {
|
||||
"sigmas": ("SIGMAS", ),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
|
||||
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, steps, start_step, end_step, shift, use_adaptive_order, use_velocity_smoothing, convergence_threshold, smoothing_factor, sigmas=None):
|
||||
sample_scheduler, timesteps, _, _ = get_scheduler(
|
||||
scheduler="sa_ode_stable/lowstep",
|
||||
steps=steps,
|
||||
start_step=start_step, end_step=end_step, shift=shift,
|
||||
device=device,
|
||||
sigmas=sigmas,
|
||||
log_timesteps=True,
|
||||
use_adaptive_order=use_adaptive_order,
|
||||
use_velocity_smoothing=use_velocity_smoothing,
|
||||
convergence_threshold=convergence_threshold,
|
||||
smoothing_factor=smoothing_factor
|
||||
)
|
||||
|
||||
scheduler_dict = {
|
||||
"sample_scheduler": sample_scheduler,
|
||||
"timesteps": timesteps,
|
||||
}
|
||||
|
||||
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
|
||||
|
||||
rope_functions = ["default", "comfy", "comfy_chunked"]
|
||||
class WanVideoRoPEFunction:
|
||||
@@ -2396,7 +2244,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoBlockList": WanVideoBlockList,
|
||||
"WanVideoTextEncodeCached": WanVideoTextEncodeCached,
|
||||
"WanVideoAddExtraLatent": WanVideoAddExtraLatent,
|
||||
"WanVideoScheduler": WanVideoScheduler,
|
||||
"WanVideoAddStandInLatent": WanVideoAddStandInLatent,
|
||||
"WanVideoAddControlEmbeds": WanVideoAddControlEmbeds,
|
||||
"WanVideoAddMTVMotion": WanVideoAddMTVMotion,
|
||||
@@ -2404,7 +2251,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoAddPusaNoise": WanVideoAddPusaNoise,
|
||||
"WanVideoAnimateEmbeds": WanVideoAnimateEmbeds,
|
||||
"WanVideoAddLucyEditLatents": WanVideoAddLucyEditLatents,
|
||||
"WanVideoSchedulerSA_ODE": WanVideoSchedulerSA_ODE,
|
||||
"WanVideoAddBindweaveEmbeds": WanVideoAddBindweaveEmbeds,
|
||||
"TextImageEncodeQwenVL": TextImageEncodeQwenVL,
|
||||
"WanVideoUniLumosEmbeds": WanVideoUniLumosEmbeds,
|
||||
@@ -2447,7 +2293,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoAddPusaNoise": "WanVideo Add Pusa Noise",
|
||||
"WanVideoAnimateEmbeds": "WanVideo Animate Embeds",
|
||||
"WanVideoAddLucyEditLatents": "WanVideo Add LucyEdit Latents",
|
||||
"WanVideoSchedulerSA_ODE": "WanVideo Scheduler SA-ODE",
|
||||
"WanVideoAddBindweaveEmbeds": "WanVideo Add Bindweave Embeds",
|
||||
"WanVideoUniLumosEmbeds": "WanVideo UniLumos Embeds",
|
||||
"WanVideoAddTTMLatents": "WanVideo Add TTMLatents",
|
||||
|
||||
200
nodes_sampler.py
200
nodes_sampler.py
@@ -1055,6 +1055,8 @@ class WanVideoSampler:
|
||||
rope_function = "comfy" # only works with this currently
|
||||
|
||||
freqs = None
|
||||
|
||||
riflex_freq_index = 0 if riflex_freq_index is None else riflex_freq_index
|
||||
transformer.rope_embedder.k = None
|
||||
transformer.rope_embedder.num_frames = None
|
||||
d = transformer.dim // transformer.num_heads
|
||||
@@ -3274,13 +3276,211 @@ class WanVideoSamplerFromSettings(WanVideoSampler):
|
||||
def process(self, sampler_inputs):
|
||||
return super().process(**sampler_inputs)
|
||||
|
||||
|
||||
class WanVideoSamplerExtraArgs():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
},
|
||||
"optional": {
|
||||
"riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}),
|
||||
"feta_args": ("FETAARGS", ),
|
||||
"context_options": ("WANVIDCONTEXT", ),
|
||||
"cache_args": ("CACHEARGS", ),
|
||||
"slg_args": ("SLGARGS", ),
|
||||
"rope_function": (rope_functions, {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. Chunked version has reduced peak VRAM usage when not using torch.compile"}),
|
||||
"loop_args": ("LOOPARGS", ),
|
||||
"experimental_args": ("EXPERIMENTALARGS", ),
|
||||
"unianimate_poses": ("UNIANIMATE_POSE", ),
|
||||
"fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ),
|
||||
"uni3c_embeds": ("UNI3C_EMBEDS", ),
|
||||
"multitalk_embeds": ("MULTITALK_EMBEDS", ),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("WANVIDSAMPLEREXTRAARGS",)
|
||||
RETURN_NAMES = ("extra_args", )
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
|
||||
def process(self, *args, **kwargs):
|
||||
return kwargs,
|
||||
|
||||
|
||||
class WanVideoSamplerv2(WanVideoSampler):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("WANVIDEOMODEL",),
|
||||
"image_embeds": ("WANVIDIMAGE_EMBEDS", ),
|
||||
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
|
||||
"scheduler": ("WANVIDEOSCHEDULER",),
|
||||
},
|
||||
"optional": {
|
||||
"text_embeds": ("WANVIDEOTEXTEMBEDS", ),
|
||||
"samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
|
||||
"add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}),
|
||||
"extra_args": ("WANVIDSAMPLEREXTRAARGS", ),
|
||||
}
|
||||
}
|
||||
|
||||
def process(self, *args, extra_args=None, **kwargs):
|
||||
import inspect
|
||||
params = inspect.signature(WanVideoSampler.process).parameters
|
||||
args_dict = {name: kwargs.get(name, param.default if param.default is not inspect.Parameter.empty else None)
|
||||
for name, param in params.items() if name != "self"}
|
||||
|
||||
if extra_args is not None:
|
||||
args_dict.update(extra_args)
|
||||
|
||||
return super().process(**args_dict)
|
||||
|
||||
|
||||
class WanVideoScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"scheduler": (scheduler_list, {"default": "unipc"}),
|
||||
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
|
||||
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
|
||||
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
|
||||
},
|
||||
"optional": {
|
||||
"sigmas": ("SIGMAS", ),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
|
||||
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None):
|
||||
sample_scheduler, timesteps, start_idx, end_idx = get_scheduler(
|
||||
scheduler, steps, start_step, end_step, shift, device, sigmas=sigmas, log_timesteps=True)
|
||||
|
||||
scheduler_dict = {
|
||||
"sample_scheduler": sample_scheduler,
|
||||
"timesteps": timesteps,
|
||||
}
|
||||
|
||||
try:
|
||||
from server import PromptServer
|
||||
import io
|
||||
import base64
|
||||
import matplotlib.pyplot as plt
|
||||
except:
|
||||
PromptServer = None
|
||||
if unique_id and PromptServer is not None:
|
||||
try:
|
||||
# Plot sigmas and save to a buffer
|
||||
sigmas_np = sample_scheduler.full_sigmas.cpu().numpy()
|
||||
if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
|
||||
sigmas_np = np.append(sigmas_np, 0.0)
|
||||
buf = io.BytesIO()
|
||||
fig = plt.figure(facecolor='#353535')
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_facecolor('#353535') # Set axes background color
|
||||
x_values = range(0, len(sigmas_np))
|
||||
ax.plot(x_values, sigmas_np)
|
||||
# Annotate each sigma value
|
||||
ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3) # Small dots at each sigma
|
||||
for x, y in zip(x_values, sigmas_np):
|
||||
# Show all annotations if few steps, or just show split step annotations
|
||||
show_annotation = len(sigmas_np) <= 10
|
||||
is_split_step = (start_idx > 0 and x == start_idx) or (end_idx != -1 and x == end_idx + 1)
|
||||
|
||||
if show_annotation or is_split_step:
|
||||
color = 'orange'
|
||||
if is_split_step:
|
||||
color = 'yellow'
|
||||
ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color=color, fontsize=12)
|
||||
ax.set_xticks(x_values)
|
||||
ax.set_title("Sigmas", color='white') # Title font color
|
||||
ax.set_xlabel("Step", color='white') # X label font color
|
||||
ax.set_ylabel("Sigma Value", color='white') # Y label font color
|
||||
ax.tick_params(axis='x', colors='white', labelsize=10) # X tick color
|
||||
ax.tick_params(axis='y', colors='white', labelsize=10) # Y tick color
|
||||
# Add split point if end_step is defined
|
||||
end_idx += 1
|
||||
if end_idx != -1 and 0 <= end_idx < len(sigmas_np) - 1:
|
||||
ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split')
|
||||
# Add split point if start_step is defined
|
||||
if start_idx > 0 and 0 <= start_idx < len(sigmas_np):
|
||||
ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
|
||||
if (end_idx != -1 and 0 <= end_idx < len(sigmas_np)) or (start_idx > 0 and 0 <= start_idx < len(sigmas_np)):
|
||||
handles, labels = ax.get_legend_handles_labels()
|
||||
if labels:
|
||||
ax.legend()
|
||||
if start_idx < end_idx and 0 <= start_idx < len(sigmas_np) and 0 < end_idx < len(sigmas_np):
|
||||
ax.axvspan(start_idx, end_idx, color='lightblue', alpha=0.1, label='Sampled Range')
|
||||
plt.tight_layout()
|
||||
plt.savefig(buf, format='png')
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
||||
buf.close()
|
||||
|
||||
# Send as HTML img tag with base64 data
|
||||
html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>"
|
||||
PromptServer.instance.send_progress_text(html_img, unique_id)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to send sigmas plot: {e}")
|
||||
pass
|
||||
|
||||
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
|
||||
|
||||
class WanVideoSchedulerv2(WanVideoScheduler):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"scheduler": (scheduler_list, {"default": "unipc"}),
|
||||
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
|
||||
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
|
||||
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
|
||||
},
|
||||
"optional": {
|
||||
"sigmas": ("SIGMAS", ),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDEOSCHEDULER",)
|
||||
RETURN_NAMES = ("scheduler",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "WanVideoWrapper"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def process(self, *args, **kwargs):
|
||||
sigmas, steps, shift, scheduler_dict, start_step, end_step = super().process(*args, **kwargs)
|
||||
return scheduler_dict,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoSampler": WanVideoSampler,
|
||||
"WanVideoSamplerSettings": WanVideoSamplerSettings,
|
||||
"WanVideoSamplerFromSettings": WanVideoSamplerFromSettings,
|
||||
"WanVideoSamplerv2": WanVideoSamplerv2,
|
||||
"WanVideoSamplerExtraArgs": WanVideoSamplerExtraArgs,
|
||||
"WanVideoScheduler": WanVideoScheduler,
|
||||
"WanVideoSchedulerv2": WanVideoSchedulerv2,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoSampler": "WanVideo Sampler",
|
||||
"WanVideoSamplerSettings": "WanVideo Sampler Settings",
|
||||
"WanVideoSamplerFromSettings": "WanVideo Sampler From Settings",
|
||||
"WanVideoSamplerv2": "WanVideo Sampler v2",
|
||||
"WanVideoSamplerExtraArgs": "WanVideoSampler v2 Extra Args",
|
||||
"WanVideoScheduler": "WanVideo Scheduler",
|
||||
"WanVideoSchedulerv2": "WanVideo Scheduler v2",
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ def _apply_custom_sigmas(sample_scheduler, sigmas, device):
|
||||
|
||||
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, **kwargs):
|
||||
timesteps = None
|
||||
if sigmas is not None:
|
||||
steps = len(sigmas) - 1
|
||||
if scheduler == 'vibt_unipc':
|
||||
sample_scheduler = ViBTScheduler()
|
||||
sample_scheduler.set_parameters(shift=shift)
|
||||
|
||||
Reference in New Issue
Block a user