From 386245a26427a64f364f66f6fecd03b3bccfd7f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 10:25:35 +0300 Subject: split shared.py into multiple files; should resolve all circular reference import errors related to shared.py --- modules/shared_state.py | 159 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 modules/shared_state.py (limited to 'modules/shared_state.py') diff --git a/modules/shared_state.py b/modules/shared_state.py new file mode 100644 index 00000000..3dc9c788 --- /dev/null +++ b/modules/shared_state.py @@ -0,0 +1,159 @@ +import datetime +import logging +import threading +import time + +from modules import errors, shared, devices +from typing import Optional + +log = logging.getLogger(__name__) + + +class State: + skipped = False + interrupted = False + job = "" + job_no = 0 + job_count = 0 + processing_has_refined_job_count = False + job_timestamp = '0' + sampling_step = 0 + sampling_steps = 0 + current_latent = None + current_image = None + current_image_sampling_step = 0 + id_live_preview = 0 + textinfo = None + time_start = None + server_start = None + _server_command_signal = threading.Event() + _server_command: Optional[str] = None + + def __init__(self): + self.server_start = time.time() + + @property + def need_restart(self) -> bool: + # Compatibility getter for need_restart. + return self.server_command == "restart" + + @need_restart.setter + def need_restart(self, value: bool) -> None: + # Compatibility setter for need_restart. + if value: + self.server_command = "restart" + + @property + def server_command(self): + return self._server_command + + @server_command.setter + def server_command(self, value: Optional[str]) -> None: + """ + Set the server command to `value` and signal that it's been set. + """ + self._server_command = value + self._server_command_signal.set() + + def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]: + """ + Wait for server command to get set; return and clear the value and signal. + """ + if self._server_command_signal.wait(timeout): + self._server_command_signal.clear() + req = self._server_command + self._server_command = None + return req + return None + + def request_restart(self) -> None: + self.interrupt() + self.server_command = "restart" + log.info("Received restart request") + + def skip(self): + self.skipped = True + log.info("Received skip request") + + def interrupt(self): + self.interrupted = True + log.info("Received interrupt request") + + def nextjob(self): + if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + + self.job_no += 1 + self.sampling_step = 0 + self.current_image_sampling_step = 0 + + def dict(self): + obj = { + "skipped": self.skipped, + "interrupted": self.interrupted, + "job": self.job, + "job_count": self.job_count, + "job_timestamp": self.job_timestamp, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return obj + + def begin(self, job: str = "(unknown)"): + self.sampling_step = 0 + self.job_count = -1 + self.processing_has_refined_job_count = False + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.id_live_preview = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + self.time_start = time.time() + self.job = job + devices.torch_gc() + log.info("Starting job %s", job) + + def end(self): + duration = time.time() - self.time_start + log.info("Ending job %s (%.2f seconds)", self.job, duration) + self.job = "" + self.job_count = 0 + + devices.torch_gc() + + def set_current_image(self): + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + if not shared.parallel_processing_allowed: + return + + if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1: + self.do_set_current_image() + + def do_set_current_image(self): + if self.current_latent is None: + return + + import modules.sd_samplers + + try: + if shared.opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during genration, VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() + + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 -- cgit v1.2.3