1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/shared_state.py
Vladimir Mandic f69b4c5589 disabling live preview should not disable progress updates
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-10-21 11:07:47 -04:00

299 lines
10 KiB
Python

import os
import re
import sys
import uuid
import time
import datetime
from modules.errors import log, display
# Opt-in tracing switches, read once from the environment at import time.
# SD_STATE_DEBUG turns on the state trace logging and, when set, also implies
# history tracing; SD_STATE_HISTORY turns on history tracing on its own.
debug_output = os.getenv('SD_STATE_DEBUG')
debug_history = debug_output if debug_output else os.getenv('SD_STATE_HISTORY')
class State:
    """Mutable record of the currently running job plus a bounded job history.

    Tracks job/step counters, skip/interrupt/pause flags, the latest latent
    and preview image, and appends begin/output/end entries to
    ``state_history``.
    """

    # NOTE: `state_history` is a class-level list that methods append to in
    # place, so it is shared by every instance of this class.
    state_history = []
    job_history = 0
    task_history = 0
    image_history = 0
    latent_history = 0
    id = 0
    results = []  # class-level default; __init__ and clear() rebind a per-instance list
    skipped = False
    interrupted = False
    paused = False
    job = ""
    job_no = 0
    job_count = 0
    batch_no = 0
    batch_count = 0
    frame_count = 0
    total_jobs = 0
    job_timestamp = '0'
    _sampling_step = 0  # backing field for the `sampling_step` property
    sampling_steps = 0
    current_latent = None
    current_noise_pred = None
    current_sigma = None
    current_sigma_next = None
    current_image = None
    current_image_sampling_step = 0
    id_live_preview = 0
    textinfo = None
    prediction_type = "epsilon"
    api = False
    disable_preview = False
    preview_job = -1  # job_no a preview is being rendered for, -1 when none
    time_start = None
    duration = None
    need_restart = False
    server_start = time.time()  # evaluated once, when the class body executes
    oom = False

    def __init__(self):
        """Give the instance its own results list and log its creation."""
        # Without this, outputs() called before begin()/clear() would mutate
        # the class-level `results` list in place.
        self.results = []
        log.debug(f'State initialized: id={id(self)}')

    def __str__(self) -> str:
        """Return a one-line summary of counters, flags and the calling functions."""
        status = ' '
        status += 'skipped ' if self.skipped else ''
        status += 'interrupted ' if self.interrupted else ''
        status += 'paused ' if self.paused else ''
        status += 'restart ' if self.need_restart else ''
        status += 'oom ' if self.oom else ''
        status += 'api ' if self.api else ''
        try:
            # Names of the functions two and three frames above this call.
            fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}' # pylint: disable=protected-access
        except ValueError:
            # sys._getframe raises ValueError when the stack is shallower than
            # the requested depth; formatting must not fail because of that.
            fn = 'unknown'
        return f'State: ts={self.job_timestamp} job={self.job} jobs={self.job_no+1}/{self.job_count}/{self.total_jobs} step={self.sampling_step}/{self.sampling_steps} preview={self.preview_job}/{self.id_live_preview}/{self.current_image_sampling_step} status="{status.strip()}" fn={fn}'

    @property
    def sampling_step(self):
        """Current sampling step counter."""
        return self._sampling_step

    @sampling_step.setter
    def sampling_step(self, value):
        """Store the step and trace the state when SD_STATE_DEBUG is set."""
        self._sampling_step = value
        if debug_output:
            log.trace(f'State step: {self}')

    def skip(self):
        """Set the `skipped` flag."""
        log.debug('State: skip requested')
        self.skipped = True

    def interrupt(self):
        """Set the `interrupted` flag."""
        log.debug('State: interrupt requested')
        self.interrupted = True

    def pause(self):
        """Toggle the `paused` flag."""
        self.paused = not self.paused
        log.debug(f'State: {"pause" if self.paused else "continue"} requested')

    def nextjob(self):
        """Advance to the next job: refresh the preview, bump `job_no`, run torch_gc."""
        import modules.devices
        self.do_set_current_image()
        self.job_no += 1
        # self.sampling_step = 0
        self.current_image_sampling_step = 0
        if debug_output:
            log.trace(f'State next: {self}')
        modules.devices.torch_gc()

    def dict(self):
        """Return a plain dict with the basic job and step fields."""
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }
        return obj

    def status(self):
        """Build a `models.ResStatus` describing progress, timing and run status."""
        from modules import progress
        from modules.api import models
        res = models.ResStatus(
            task=self.job,
            current=progress.current_task or '',
            id=self.id,
            job=max(self.job_no, 0),
            jobs=max(self.frame_count, self.job_count, self.job_no),
            total=self.total_jobs,
            timestamp=self.job_timestamp if self.job != '' else None,
            step=self.sampling_step,
            steps=self.sampling_steps,
            queued=len(progress.pending_tasks),
            status='unknown',
            uptime=round(time.time() - self.server_start),
        )
        # Convert per-job step numbers into overall numbers across all jobs.
        res.step = res.steps * res.job + res.step
        res.steps = res.steps * res.jobs
        res.progress = round(min(1, abs(res.step / res.steps) if res.steps > 0 else 0), 2)
        res.elapsed = round(time.time() - self.time_start, 2) if self.time_start is not None else None
        # Linear extrapolation: total predicted time = elapsed / fraction done.
        predicted = round(res.elapsed / res.progress, 2) if res.progress > 0 and res.elapsed is not None else None
        res.eta = round(predicted - res.elapsed, 2) if predicted is not None else None
        # Flag precedence: paused, then interrupted, then skipped.
        if self.paused:
            res.status = 'paused'
        elif self.interrupted:
            res.status = 'interrupted'
        elif self.skipped:
            res.status = 'skipped'
        else:
            res.status = 'running' if self.job != '' else 'idle'
        return res

    def find(self, task_id:str):
        """Return the most recent history entry whose id equals `task_id`, or None."""
        for job in reversed(self.state_history):
            if job['id'] == task_id:
                return job
        return None

    def history(self, op:str, task_id:str=None, results:list=None):
        """Append an entry to `state_history`, dropping the oldest beyond 10000.

        Args:
            op: operation name, stored lower-cased.
            task_id: entry id; falls back to `self.id` when falsy.
            results: outputs to record. The given list is stored as-is (not
                copied); when omitted, each entry gets its own new empty list.
        """
        job = {
            'id': task_id or self.id,
            'job': self.job.lower(),
            'op': op.lower(),
            'timestamp': self.time_start,
            'duration': self.duration,
            # A fresh list per entry: a shared `[]` default would be aliased
            # by every entry recorded without explicit results.
            'outputs': [] if results is None else results,
        }
        self.state_history.append(job)
        count = len(self.state_history)
        if count > 10000:
            del self.state_history[0]
        if debug_history:
            log.trace(f'State history: jobs={count} {job}')

    def outputs(self, results):
        """Add one result or a list of results and record an 'output' history entry."""
        if isinstance(results, list):
            self.results += results
        else:
            self.results.append(results)
        if len(self.results) > 0:
            self.history('output', self.id, results=self.results)

    def get_id(self, task_id:str=None):
        """Normalize a task id to a string.

        Generates 15 hex characters from a random UUID when `task_id` is None
        or 0. If the text contains a parenthesized part, the content of the
        first pair of parentheses is returned; otherwise the text itself.
        """
        if task_id is None or task_id == 0:
            task_id = uuid.uuid4().hex[:15]
        if not isinstance(task_id, str):
            task_id = str(task_id)
        match = re.search(r'\((.*?)\)', task_id)
        return match.group(1) if match else task_id

    def clear(self):
        """Reset the per-job identity, counters, pause flag and results list."""
        self.id = ''
        self.job = ''
        self.job_count = 0
        self.job_no = 0
        self.frame_count = 0
        self.preview_job = -1
        self.duration = None
        self.paused = False
        self.results = []

    def begin(self, title="", task_id=0, api=None):
        """Start a new job: reset state, assign an id, record 'begin', return the id."""
        import modules.devices
        self.clear()
        # The interrupted flag is kept only for titles starting with 'Save'.
        self.interrupted = self.interrupted if title.startswith('Save') else False
        self.skipped = False
        self.job_history += 1
        self.total_jobs += 1
        self.current_image = None
        self.current_image_sampling_step = 0
        self.current_latent = None
        self.current_noise_pred = None
        self.current_sigma = None
        self.current_sigma_next = None
        self.id_live_preview = 0
        self.id = self.get_id(task_id)
        self.job = title
        self.job_count = 1 # cannot be less than 1 on new job
        self.batch_no = 0
        self.batch_count = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self._sampling_step = 0  # written directly, bypassing the tracing setter
        self.sampling_steps = 0
        self.textinfo = None
        self.prediction_type = "epsilon"
        self.api = api or self.api
        self.time_start = time.time()
        self.history('begin', self.id)
        if debug_output:
            log.trace(f'State begin: {self}')
        modules.devices.torch_gc()
        return self.id

    def end(self, task_id=None):
        """Finish a job: record an 'end' history entry, clear state, run torch_gc.

        When `task_id` matches a history entry, id, job and duration are
        restored from that entry before the 'end' entry is written.
        """
        import modules.devices
        if debug_output:
            log.trace(f'State end: {self}')
        if task_id is not None:
            prev_job = self.find(task_id)
            if prev_job is not None:
                self.id = prev_job['id']
                self.job = prev_job['job']
                self.duration = round(time.time() - prev_job['timestamp'], 3) if prev_job['timestamp'] is not None else None
        self.time_start = time.time()
        self.history('end', task_id or self.id)
        self.clear()
        modules.devices.torch_gc()

    def step(self, step:int=1):
        """Advance the sampling step counter by `step`."""
        self.sampling_step += step

    def update(self, job:str, steps:int=0, jobs:int=0):
        """Adjust the planned step and job totals.

        'Ignore' changes nothing besides `task_history`; 'Grid' overwrites the
        totals; any other name adds `steps * jobs` steps and `jobs` jobs.
        """
        self.task_history += 1
        # self._sampling_step = 0
        if job == 'Ignore':
            return
        elif job == 'Grid':
            self.sampling_steps = steps
            self.job_count = jobs
        else:
            self.sampling_steps += (steps * jobs)
            self.job_count += jobs
        # self.job = job
        if debug_output:
            log.trace(f'State update: {self} steps={steps} jobs={jobs}')

    def set_current_image(self):
        """Refresh the live preview if enabled and enough steps have passed.

        Returns True only when a new preview image was assigned.
        """
        if self.job == 'VAE' or self.job == 'Upscale': # avoid generating preview while vae is running
            return False
        from modules.shared import opts, cmd_opts
        if cmd_opts.lowvram or self.api or (opts.show_progress_every_n_steps <= 0):
            return False
        if (not self.disable_preview) and (abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps):
            return self.do_set_current_image()
        return False

    def do_set_current_image(self):
        """Decode the current latent into a preview image.

        Returns True when an image was assigned, False when skipped or failed.
        """
        if (self.current_latent is None) or self.disable_preview or (self.preview_job == self.job_no):
            return False
        from modules import shared, sd_samplers
        # Mark this job as being previewed; the check above uses it to skip
        # another preview for the same job until the marker is reset.
        self.preview_job = self.job_no
        try:
            sample = self.current_latent
            self.current_image_sampling_step = self.sampling_step
            try:
                if self.current_noise_pred is not None and self.current_sigma is not None and self.current_sigma_next is not None:
                    original_sample = sample - (self.current_noise_pred * (self.current_sigma_next-self.current_sigma))
                    if self.prediction_type in {"epsilon", "flow_prediction"}:
                        sample = original_sample - (self.current_noise_pred * self.current_sigma)
                    elif self.prediction_type == "v_prediction":
                        sample = self.current_noise_pred * (-self.current_sigma / (self.current_sigma**2 + 1) ** 0.5) + (original_sample / (self.current_sigma**2 + 1)) # pylint: disable=invalid-unary-operand-type
            except Exception:
                pass # ignore sigma errors
            image = sd_samplers.samples_to_image_grid(sample) if shared.opts.show_progress_grid else sd_samplers.sample_to_image(sample)
            self.assign_current_image(image)
            return True
        except Exception as e:
            log.error(f'State image: last={self.id_live_preview} step={self.sampling_step} {e}')
            display(e, 'State image')
            return False
        finally:
            # Always release the marker so a failure cannot leave it latched.
            self.preview_job = -1

    def assign_current_image(self, image):
        """Store `image` as the current preview and bump the preview counter."""
        self.current_image = image
        self.id_live_preview += 1