From 386245a26427a64f364f66f6fecd03b3bccfd7f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 10:25:35 +0300 Subject: split shared.py into multiple files; should resolve all circular reference import errors related to shared.py --- modules/shared_init.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 modules/shared_init.py (limited to 'modules/shared_init.py') diff --git a/modules/shared_init.py b/modules/shared_init.py new file mode 100644 index 00000000..e7fc18d2 --- /dev/null +++ b/modules/shared_init.py @@ -0,0 +1,51 @@ +import os + +import torch + +from modules import shared +from modules.shared import cmd_opts + +import sys +sys.setrecursionlimit(1000) + + +def initialize(): + """Initializes fields inside the shared module in a controlled manner. + + Should be called early because some other modules you can import mingt need these fields to be already set. + """ + + os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) + + from modules import options, shared_options + shared.options_templates = shared_options.options_templates + shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) + if os.path.exists(shared.config_filename): + shared.opts.load(shared.config_filename) + + from modules import devices + devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ + (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) + + devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + + shared.device = devices.device + shared.weight_load_location = None if cmd_opts.lowram else "cpu" + + from modules import shared_state + shared.state = shared_state.State() + + from modules import styles + shared.prompt_styles = styles.StyleDatabase(shared.styles_filename) + + from modules import interrogate + shared.interrogator = interrogate.InterrogateModels("interrogate") + + from modules import shared_total_tqdm + shared.total_tqdm = shared_total_tqdm.TotalTQDM() + + from modules import memmon, devices + shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts) + shared.mem_mon.start() + -- cgit v1.2.3