1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/ras/ras_manager.py
Vladimir Mandic 6cf445d317 add ras-sd35 experimental
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-02-18 18:47:42 -05:00

105 lines
4.9 KiB
Python

class ras_manager:
    """Hold RAS configuration, the per-step skip schedule, and per-step runtime state.

    Configuration comes from the defaults below or from `set_parameters`; the
    dynamic attributes are advanced step by step via `reset_cache` and `increase_step`.
    """

    def __init__(self):
        ## configurable
        self.metric = "std"
        self.patch_size = 2
        self.scheduler_start_step = 4
        self.sample_ratio = 0.5
        self.starvation_scale = 1.0
        self.vae_size = 8
        self.high_ratio = 0.3
        # NOTE(review): generate_skip_token_list multiplies this by an int when
        # skip_num_step_length > 0, so a number is expected there; the [] default is
        # only harmless while skip_num_step_length == 0 (static dropping)
        self.skip_num_step = []
        self.skip_num_step_length = 0
        # applied by sdnext pipeline in ras/__init__.py
        self.scheduler_end_step = 0
        self.error_reset_steps = [0, 0]
        self.num_steps = 0
        self.height = 0
        self.width = 0
        ## dynamic
        self.current_step = 0
        self.is_RAS_step = False
        self.is_next_RAS_step = False
        self.cached_index = None
        self.other_index = None
        self.cached_patchified_index = None
        self.other_patchified_index = None
        self.image_rotary_emb_skip = None
        self.cached_scaled_noise = None
        self.skip_token_num_list = []

    def __str__(self):
        """Return a one-line summary of the configuration."""
        return f'steps={self.num_steps} start={self.scheduler_start_step} end={self.scheduler_end_step} patch={self.patch_size} metric={self.metric} reset={self.error_reset_steps} ratio={self.sample_ratio} starvation={self.starvation_scale} vae={self.vae_size} high={self.high_ratio} skip={self.skip_num_step} length={self.skip_num_step_length}'

    def set_parameters(self, args):
        """Copy RAS settings from `args` and rebuild the skip-token schedule.

        Args:
            args: object with attributes patch_size, scheduler_start_step,
                scheduler_end_step, metric, error_reset_steps (comma-separated string
                of step indices), sample_ratio, num_inference_steps, skip_num_step,
                skip_num_step_length, height, width and high_ratio.

        Raises:
            ValueError: if an entry of args.error_reset_steps is not an integer.
        """
        self.patch_size = args.patch_size
        self.scheduler_start_step = args.scheduler_start_step
        self.scheduler_end_step = args.scheduler_end_step
        self.metric = args.metric
        # skip empty entries so "" or a trailing comma means "no reset steps" instead of int("") failing
        self.error_reset_steps = [int(part.strip()) for part in args.error_reset_steps.split(",") if part.strip()]
        self.sample_ratio = args.sample_ratio
        self.num_steps = args.num_inference_steps
        self.skip_num_step = args.skip_num_step
        self.skip_num_step_length = args.skip_num_step_length
        self.height = args.height
        self.width = args.width
        self.high_ratio = args.high_ratio
        self.generate_skip_token_list()

    def _total_token_num(self):
        """Return ((height // patch_size) // vae_size) * ((width // patch_size) // vae_size)."""
        return ((self.height // self.patch_size) // self.vae_size) * ((self.width // self.patch_size) // self.vae_size)

    def generate_skip_token_list(self):
        """Rebuild `skip_token_num_list`: the number of skipped tokens for each of `num_steps` steps.

        The base count is int((1 - sample_ratio) * total_tokens).
        - skip_num_step_length == 0 (static dropping): every step uses the base count.
        - otherwise (dynamic dropping): steps are grouped into blocks of
          skip_num_step_length; block b uses base + skip_num_step * (b - centre),
          floored to a multiple of 64.
        In both modes the steps before scheduler_start_step and the steps listed in
        error_reset_steps get 0. Reset steps outside [0, num_steps) are ignored.

        Raises:
            AssertionError: if a resulting count is negative or exceeds the token total.
        """
        total_tokens = self._total_token_num()
        avg_skip_token_num = int((1 - self.sample_ratio) * total_tokens)
        if self.skip_num_step_length == 0:  # static dropping
            schedule = [avg_skip_token_num] * self.num_steps
        else:  # dynamic dropping
            block_len = self.skip_num_step_length
            centre_block = ((self.num_steps + self.scheduler_start_step) // block_len) // 2
            schedule = []
            for step in range(self.num_steps):
                temp_skip_num = avg_skip_token_num + self.skip_num_step * (step // block_len - centre_block)
                schedule.append((temp_skip_num // 64) * 64)
        # no skipping before RAS starts; clamp so a start step beyond num_steps cannot raise IndexError
        for step in range(min(self.scheduler_start_step, len(schedule))):
            schedule[step] = 0
        # reset steps outside the schedule can never be current_step, so they are ignored rather than raising
        for step in self.error_reset_steps:
            if 0 <= step < len(schedule):
                schedule[step] = 0
        # assign a fresh list each call: appending to the previous list made repeated calls accumulate entries
        self.skip_token_num_list = schedule
        for count in self.skip_token_num_list:
            assert count >= 0, "Skip token number should be non-negative"
            assert count <= total_tokens, "Skip token number should not exceed the total token number"

    def _is_ras_step_index(self, step, inclusive_end=True):
        """Return True if `step` is in the RAS window and not an error-reset step.

        The window starts at scheduler_start_step and ends at scheduler_end_step,
        inclusive by default or exclusive when inclusive_end is False.
        """
        if step < self.scheduler_start_step or step in self.error_reset_steps:
            return False
        if inclusive_end:
            return step <= self.scheduler_end_step
        return step < self.scheduler_end_step

    def reset_cache(self):
        """Clear cached indices, embeddings and noise; rewind to step 0 and recompute the step flags."""
        self.cached_index = None
        self.other_index = None
        self.cached_patchified_index = None
        self.other_patchified_index = None
        self.image_rotary_emb_skip = None
        self.cached_scaled_noise = None
        self.current_step = 0
        self.is_RAS_step = self._is_ras_step_index(self.current_step)
        self.is_next_RAS_step = self._is_ras_step_index(self.current_step + 1)

    def increase_step(self):
        """Advance current_step by one and recompute is_RAS_step / is_next_RAS_step."""
        self.current_step += 1
        self.is_RAS_step = self._is_ras_step_index(self.current_step)
        # NOTE(review): unlike reset_cache, this look-ahead treats scheduler_end_step as
        # exclusive; behavior preserved as-is — confirm whether the asymmetry is intended
        self.is_next_RAS_step = self._is_ras_step_index(self.current_step + 1, inclusive_end=False)
# Module-level ras_manager instance; presumably imported and configured/mutated by the RAS pipeline code elsewhere — confirm against callers.
MANAGER = ras_manager()