"""Shared generation state: job/step progress, live preview, and job history."""
import os
import re
import sys
import uuid
import time
import datetime
from modules.errors import log, display


# Any non-empty value enables trace logging of state transitions.
debug_output = os.environ.get('SD_STATE_DEBUG', None)
# History tracing is enabled by SD_STATE_DEBUG or SD_STATE_HISTORY.
debug_history = debug_output or os.environ.get('SD_STATE_HISTORY', None)


class State:
    """Mutable tracker for the running job, live preview, and recent job history.

    Tracks job/step counters, pause/skip/interrupt flags, live-preview inputs
    (latent, noise prediction, sigmas) and the decoded preview image, plus a
    bounded list of begin/output/end events in ``state_history``.

    NOTE: attributes are declared at class level and act as defaults; most are
    rebound per instance by ``clear``/``begin``. ``state_history`` is only ever
    mutated in place, so it is shared by every instance.
    """
    state_history = []  # bounded to 10000 entries by history()
    job_history = 0
    task_history = 0
    image_history = 0
    latent_history = 0
    id = 0
    results = []
    skipped = False
    interrupted = False
    paused = False
    job = ""
    job_no = 0
    job_count = 0
    batch_no = 0
    batch_count = 0
    frame_count = 0
    total_jobs = 0
    job_timestamp = '0'
    _sampling_step = 0
    sampling_steps = 0
    current_latent = None
    current_noise_pred = None
    current_sigma = None
    current_sigma_next = None
    current_image = None
    current_image_sampling_step = 0
    id_live_preview = 0
    textinfo = None
    prediction_type = "epsilon"
    api = False
    disable_preview = False
    preview_job = -1
    time_start = None
    duration = None
    need_restart = False
    server_start = time.time()  # evaluated once, at class definition (module import)
    oom = False

    def __init__(self):
        log.debug(f'State initialized: id={id(self)}')
        # Per-instance list so outputs() never mutates the class-level default
        # (``self.results += x`` on the class list would extend it in place).
        self.results = []

    def __str__(self) -> str:
        """Return a one-line debug summary including the two calling functions."""
        status = ' '
        status += 'skipped ' if self.skipped else ''
        status += 'interrupted ' if self.interrupted else ''
        status += 'paused ' if self.paused else ''
        status += 'restart ' if self.need_restart else ''
        status += 'oom ' if self.oom else ''
        status += 'api ' if self.api else ''
        # Frames 2 and 3 are the callers of the method that formatted this
        # object (e.g. the setter -> step() -> caller). sys._getframe raises
        # ValueError when the stack is shallower than requested, so fall back.
        try:
            fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}' # pylint: disable=protected-access
        except ValueError:
            fn = 'unknown'
        return f'State: ts={self.job_timestamp} job={self.job} jobs={self.job_no+1}/{self.job_count}/{self.total_jobs} step={self.sampling_step}/{self.sampling_steps} preview={self.preview_job}/{self.id_live_preview}/{self.current_image_sampling_step} status="{status.strip()}" fn={fn}'

    @property
    def sampling_step(self):
        """Current sampling step counter."""
        return self._sampling_step

    @sampling_step.setter
    def sampling_step(self, value):
        self._sampling_step = value
        if debug_output:
            log.trace(f'State step: {self}')

    def skip(self):
        """Set the skipped flag."""
        log.debug('State: skip requested')
        self.skipped = True

    def interrupt(self):
        """Set the interrupted flag."""
        log.debug('State: interrupt requested')
        self.interrupted = True

    def pause(self):
        """Toggle the paused flag."""
        self.paused = not self.paused
        log.debug(f'State: {"pause" if self.paused else "continue"} requested')

    def nextjob(self):
        """Advance to the next job: force a preview, bump job_no, run torch_gc."""
        import modules.devices
        self.do_set_current_image()
        self.job_no += 1
        # NOTE: sampling_step is deliberately left untouched here.
        self.current_image_sampling_step = 0
        if debug_output:
            log.trace(f'State next: {self}')
        modules.devices.torch_gc()

    def dict(self):
        """Return a plain-dict summary of the main job and step counters."""
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }
        return obj

    def status(self):
        """Build and return a ``models.ResStatus`` progress snapshot.

        ``step``/``steps`` are rescaled across jobs (step = steps*job + step,
        steps = steps*jobs); progress is clamped to [0, 1] and rounded, and
        elapsed/eta are in seconds (None when no job has started).
        """
        from modules import progress
        from modules.api import models
        res = models.ResStatus(
            task=self.job,
            current=progress.current_task or '',
            id=self.id,
            job=max(self.job_no, 0),
            jobs=max(self.frame_count, self.job_count, self.job_no),
            total=self.total_jobs,
            timestamp=self.job_timestamp if self.job != '' else None,
            step=self.sampling_step,
            steps=self.sampling_steps,
            queued=len(progress.pending_tasks),
            status='unknown',
            uptime=round(time.time() - self.server_start),
        )
        res.step = res.steps * res.job + res.step
        res.steps = res.steps * res.jobs
        res.progress = round(min(1, abs(res.step / res.steps) if res.steps > 0 else 0), 2)
        res.elapsed = round(time.time() - self.time_start, 2) if self.time_start is not None else None
        predicted = round(res.elapsed / res.progress, 2) if res.progress > 0 and res.elapsed is not None else None
        res.eta = round(predicted - res.elapsed, 2) if predicted is not None else None
        # Flag precedence: paused > interrupted > skipped > running/idle.
        if self.paused:
            res.status = 'paused'
        elif self.interrupted:
            res.status = 'interrupted'
        elif self.skipped:
            res.status = 'skipped'
        else:
            res.status = 'running' if self.job != '' else 'idle'
        return res

    def find(self, task_id:str):
        """Return the most recent history entry whose id equals task_id, else None."""
        for job in reversed(self.state_history):
            if job['id'] == task_id:
                return job
        return None

    def history(self, op:str, task_id:str=None, results:list=None):
        """Append an event to ``state_history`` (dropping the oldest past 10000).

        Args:
            op: event name, stored lower-cased (e.g. 'begin', 'output', 'end').
            task_id: id to record; falls back to ``self.id`` when falsy.
            results: outputs to attach. A copy is stored so later mutation of
                the caller's list (e.g. ``self.results``) does not rewrite past
                entries; None (the default) records a fresh empty list.
        """
        job = {
            'id': task_id or self.id,
            'job': self.job.lower(),
            'op': op.lower(),
            'timestamp': self.time_start,
            'duration': self.duration,
            'outputs': [] if results is None else list(results),
        }
        self.state_history.append(job)
        count = len(self.state_history)  # length before trimming; logged below
        if count > 10000:
            del self.state_history[0]
        if debug_history:
            log.trace(f'State history: jobs={count} {job}')

    def outputs(self, results):
        """Accumulate job outputs (a list is extended, anything else appended) and log them."""
        if isinstance(results, list):
            self.results += results
        else:
            self.results.append(results)
        if self.results:
            self.history('output', self.id, results=self.results)

    def get_id(self, task_id:str=None):
        """Normalize a task id.

        None or 0 yields a random 15-character hex id; other values are
        stringified. If the id contains a parenthesized part, the contents of
        the first such part are returned instead of the whole string.
        """
        if task_id is None or task_id == 0:
            task_id = uuid.uuid4().hex[:15]
        if not isinstance(task_id, str):
            task_id = str(task_id)
        match = re.search(r'\((.*?)\)', task_id)
        return match.group(1) if match else task_id

    def clear(self):
        """Reset job identity, counters, duration, pause flag and results.

        Does not touch the interrupted/skipped flags or preview data.
        """
        self.id = ''
        self.job = ''
        self.job_count = 0
        self.job_no = 0
        self.frame_count = 0
        self.preview_job = -1
        self.duration = None
        self.paused = False
        self.results = []

    def begin(self, title="", task_id=0, api=None):
        """Start a new job and return its normalized id.

        Resets counters and preview data, records a 'begin' history entry and
        runs torch_gc. The interrupted flag is kept only when title starts with
        'Save'; otherwise it is cleared.
        """
        import modules.devices
        self.clear()
        self.interrupted = self.interrupted if title.startswith('Save') else False
        self.skipped = False
        self.job_history += 1
        self.total_jobs += 1
        self.current_image = None
        self.current_image_sampling_step = 0
        self.current_latent = None
        self.current_noise_pred = None
        self.current_sigma = None
        self.current_sigma_next = None
        self.id_live_preview = 0
        self.id = self.get_id(task_id)
        self.job = title
        self.job_count = 1 # cannot be less than 1 on new job
        self.batch_no = 0
        self.batch_count = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self._sampling_step = 0  # bypass the setter to avoid a trace log here
        self.sampling_steps = 0
        self.textinfo = None
        self.prediction_type = "epsilon"
        self.api = api or self.api
        self.time_start = time.time()
        self.history('begin', self.id)
        if debug_output:
            log.trace(f'State begin: {self}')
        modules.devices.torch_gc()
        return self.id

    def end(self, task_id=None):
        """Finish the current job, record an 'end' history entry and clear state.

        When task_id matches a history entry, its id/job are restored and the
        duration (seconds, 3 decimals) is measured from that entry's timestamp.
        """
        import modules.devices
        if debug_output:
            log.trace(f'State end: {self}')
        if task_id is not None:
            prev_job = self.find(task_id)
            if prev_job is not None:
                self.id = prev_job['id']
                self.job = prev_job['job']
                self.duration = round(time.time() - prev_job['timestamp'], 3) if prev_job['timestamp'] is not None else None
        self.time_start = time.time()
        self.history('end', task_id or self.id)
        self.clear()
        modules.devices.torch_gc()

    def step(self, step:int=1):
        """Advance the sampling step counter (goes through the traced setter)."""
        self.sampling_step += step

    def update(self, job:str, steps:int=0, jobs:int=0):
        """Register additional work for the current task.

        'Ignore' only counts the call; 'Grid' replaces the step/job totals;
        any other job adds steps*jobs to sampling_steps and jobs to job_count.
        """
        self.task_history += 1
        if job == 'Ignore':
            return
        elif job == 'Grid':
            self.sampling_steps = steps
            self.job_count = jobs
        else:
            self.sampling_steps += (steps * jobs)
            self.job_count += jobs
        if debug_output:
            log.trace(f'State update: {self} steps={steps} jobs={jobs}')

    def set_current_image(self):
        """Render a live preview if one is due; return True when an image was produced.

        Skipped for 'VAE'/'Upscale' jobs, lowvram, API mode, when previews are
        disabled, or when fewer than ``opts.show_progress_every_n_steps`` steps
        have passed since the last preview.
        """
        if self.job in ('VAE', 'Upscale'): # avoid generating preview while vae is running
            return False
        from modules.shared import opts, cmd_opts
        if cmd_opts.lowvram or self.api or (opts.show_progress_every_n_steps <= 0):
            return False
        if (not self.disable_preview) and (abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps):
            return self.do_set_current_image()
        return False

    def do_set_current_image(self):
        """Decode ``current_latent`` into ``current_image``; return True on success.

        ``preview_job`` is set to the current job_no while decoding and reset to
        -1 afterwards; it appears to act as an in-progress guard. Decode errors
        are logged and reported, not raised.
        """
        if (self.current_latent is None) or self.disable_preview or (self.preview_job == self.job_no):
            return False
        from modules import shared, sd_samplers
        self.preview_job = self.job_no
        try:
            sample = self.current_latent
            self.current_image_sampling_step = self.sampling_step
            try:
                # When noise prediction and both sigmas are available, adjust the
                # sample according to prediction_type before decoding (presumably
                # an estimate of the denoised sample — confirm against sampler).
                if self.current_noise_pred is not None and self.current_sigma is not None and self.current_sigma_next is not None:
                    original_sample = sample - (self.current_noise_pred * (self.current_sigma_next-self.current_sigma))
                    if self.prediction_type in {"epsilon", "flow_prediction"}:
                        sample = original_sample - (self.current_noise_pred * self.current_sigma)
                    elif self.prediction_type == "v_prediction":
                        sample = self.current_noise_pred * (-self.current_sigma / (self.current_sigma**2 + 1) ** 0.5) + (original_sample / (self.current_sigma**2 + 1)) # pylint: disable=invalid-unary-operand-type
            except Exception:
                pass # ignore sigma errors: best-effort, fall back to the raw latent
            image = sd_samplers.samples_to_image_grid(sample) if shared.opts.show_progress_grid else sd_samplers.sample_to_image(sample)
            self.assign_current_image(image)
            self.preview_job = -1
            return True
        except Exception as e:
            self.preview_job = -1
            log.error(f'State image: last={self.id_live_preview} step={self.sampling_step} {e}')
            display(e, 'State image')
            return False

    def assign_current_image(self, image):
        """Store a preview image and bump the live-preview id."""
        self.current_image = image
        self.id_live_preview += 1