From 4b8a192f680101de247dca79e48974b53bf961fe Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 29 Oct 2022 16:36:43 +0900 Subject: add optimizer save option to shared.opts --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..065b893d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -286,6 +286,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From 3178c35224467893cf8dcedb1028c59c6c23db58 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:16:32 +0900 Subject: resolve conflicts --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 065b893d..959937d7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -285,7 +285,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.3 From 9b5f85ac83f864310fe19c9deab6670bad695b0d Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:18:04 +0900 Subject: first revert --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 959937d7..7e8c552b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -286,7 +286,6 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From 7ea5956ad5fa925f92116e8a3bf78d7f6517b654 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:18:55 +0900 Subject: now add --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..7ecb40d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,6 +309,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From f1b6ac64e451036fb4dfabe66d79488c56c06776 Mon Sep 17 00:00:00 2001 From: Kyu♥ <3ad4gum@gmail.com> Date: Wed, 2 Nov 2022 17:24:42 +0100 Subject: Added option to preview Created images on batch completion. --- modules/shared.py | 25 ++++++++++++++++--------- modules/ui.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..d4cf32a4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -146,6 +146,9 @@ class State: self.interrupted = True def nextjob(self): + if opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -186,17 +189,21 @@ class State: """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + self.do_set_current_image() + + def do_set_current_image(self): if not parallel_processing_allowed: return + if self.current_latent is None: + return + + if opts.show_progress_grid: + self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + else: + self.current_image = sd_samplers.sample_to_image(self.current_latent) - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None: - if opts.show_progress_grid: - self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) - else: - self.current_image = sd_samplers.sample_to_image(self.current_latent) - - self.current_image_sampling_step = self.sampling_step - + self.current_image_sampling_step = self.sampling_step state = State() @@ -351,7 +358,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui.py b/modules/ui.py index 2609857e..29de1e10 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -276,7 +276,7 @@ def check_progress_call(id_part): image = gr_show(False) preview_visibility = gr_show(False) - if opts.show_progress_every_n_steps > 0: + if opts.show_progress_every_n_steps != 0: shared.state.set_current_image() image = shared.state.current_image -- cgit v1.2.3 From ccf1a15412ef6b518f9f54cc26a0ee5edf458108 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:16:19 +0300 Subject: add an option to enable installing extensions with --listen or --share --- modules/shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 024c771a..0a39cdf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision", parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") +parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) @@ -99,7 +100,7 @@ restricted_opts = { "outdir_save", } -cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.3 From 7278897982bfb640ee95f144c97ed25fb3f77ea3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 17:12:28 +0900 Subject: Update shared.py --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 4d6e1c8b..6e7a02e0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,7 +309,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From 821e2b883dbb42a187bc37379175cd55b7cd7e81 Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Fri, 4 Nov 2022 19:39:03 +0900 Subject: change option position to Training setting --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/shared.py | 1 + modules/textual_inversion/dataset.py | 5 ++--- modules/textual_inversion/textual_inversion.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7630fb81..a11e01d6 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -331,7 +331,7 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/shared.py b/modules/shared.py index 1ccb269a..e1d9bdf1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -290,6 +290,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e9d97cc1..df278dc2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -33,7 +33,6 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) - self.shuffle_tags = shuffle_tags self.dataset = [] @@ -99,7 +98,7 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - if self.tag_shuffle: + if shared.opts.shuffle_tags: tags = filename_text.split(',') random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 82dde931..0aeb0459 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -272,7 +272,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) if unload: shared.sd_model.first_stage_model.to(devices.cpu) -- cgit v1.2.3 From f316280ad3634a2343b086a6de0bfcd473e18599 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 16:48:40 +0300 Subject: fix the error that prevents from setting some options --- modules/shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..962115f6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -406,7 +406,8 @@ class Options: if key in self.data or key in self.data_labels: assert not cmd_opts.freeze_settings, "changing settings is disabled" - comp_args = opts.data_labels[key].component_args + info = opts.data_labels.get(key, None) + comp_args = info.component_args if info else None if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: raise RuntimeError(f"not possible to set {key} because it is restricted") -- cgit v1.2.3 From b8435e632f7ba0da12a2c8e9c788dda519279d24 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sat, 5 Nov 2022 02:36:47 +0800 Subject: add --cors-allow-origins cmd opt --- modules/shared.py | 7 ++++--- webui.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..e83cbcdf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) cmd_opts = parser.parse_args() restricted_opts = { @@ -147,9 +148,9 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.show_progress_every_n_steps == -1: self.do_set_current_image() - + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -198,7 +199,7 @@ class State: return if self.current_latent is None: return - + if opts.show_progress_grid: self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) else: diff --git a/webui.py b/webui.py index 81df09dd..3788af0b 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,7 @@ import importlib import signal import threading from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path @@ -93,6 +94,11 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) +def setup_cors(app): + if cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + + def create_api(app): from modules.api.api import Api api = Api(app, queue_lock) @@ -114,6 +120,7 @@ def api_only(): initialize() app = FastAPI() + setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) @@ -147,6 +154,8 @@ def webui(): # runnnig its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) if launch_api: -- cgit v1.2.3 From e9a5562b9b27a1a4f9c282637b111cefd9727a41 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:06:51 -0500 Subject: add support for tls (gradio tls options) --- modules/shared.py | 3 +++ webui.py | 22 ++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 962115f6..7a20c3af 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) cmd_opts = parser.parse_args() restricted_opts = { diff --git a/webui.py b/webui.py index 81df09dd..d366f4ca 100644 --- a/webui.py +++ b/webui.py @@ -34,7 +34,7 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork queue_lock = threading.Lock() - +server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name def wrap_queued_call(func): def f(*args, **kwargs): @@ -85,6 +85,22 @@ def initialize(): shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print(f"path: '{cmd_opts.tls_keyfile}' {type(cmd_opts.tls_keyfile)}") + print(f"path: '{cmd_opts.tls_certfile}' {type(cmd_opts.tls_certfile)}") + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -131,8 +147,10 @@ def webui(): app, local_url, share_url = demo.launch( share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, + server_name=server_name, server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, debug=cmd_opts.gradio_debug, auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, inbrowser=cmd_opts.autolaunch, -- cgit v1.2.3 From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 09:02:25 +0300 Subject: add ability to create extensions that add localizations --- javascript/ui.js | 2 ++ modules/localization.py | 6 ++++++ modules/scripts.py | 1 - modules/shared.py | 2 -- modules/ui.py | 3 +-- webui.py | 9 +++++---- 6 files changed, 14 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index 7e116465..95cfd106 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -208,4 +208,6 @@ function update_token_counter(button_id) { function restart_reload(){ document.body.innerHTML='

Reloading...

'; setTimeout(function(){location.reload()},2000) + + return [] } diff --git a/modules/localization.py b/modules/localization.py index b1810cda..f6a6f2fb 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -3,6 +3,7 @@ import os import sys import traceback + localizations = {} @@ -16,6 +17,11 @@ def list_localizations(dirname): localizations[fn] = os.path.join(dirname, file) + from modules import scripts + for file in scripts.list_scripts("localizations", ".json"): + fn, ext = os.path.splitext(file.filename) + localizations[fn] = file.path + def localization_js(current_localization_name): fn = localizations.get(current_localization_name, None) diff --git a/modules/scripts.py b/modules/scripts.py index 366c90d7..637b2329 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,7 +3,6 @@ import sys import traceback from collections import namedtuple -import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing diff --git a/modules/shared.py b/modules/shared.py index 70b998ff..e8bacd3c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -localization.list_localizations(cmd_opts.localizations_dir) - def realesrgan_models_names(): import modules.realesrgan_model diff --git a/modules/ui.py b/modules/ui.py index 76ca9b07..23643c22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call): shared.state.need_restart = True restart_gradio.click( - fn=request_restart, + _js='restart_reload', inputs=[], outputs=[], - _js='restart_reload' ) if column is not None: diff --git a/webui.py b/webui.py index a5a520f0..4342a962 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler, extensions +from modules import devices, sd_samplers, upscaler, extensions, localization import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -28,9 +28,7 @@ import modules.txt2img import modules.script_callbacks import modules.ui -from modules import devices from modules import modelloader -from modules.paths import script_path from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork @@ -64,6 +62,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers @@ -99,7 +98,6 @@ def initialize(): else: print("Running with TLS") - # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -185,6 +183,9 @@ def webui(): print('Reloading extensions') extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') -- cgit v1.2.3 From a258fd60dbe2d68325339405a2aa72816d06d2fd Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Mon, 7 Nov 2022 00:13:58 -0800 Subject: Add CORS-allow policy launch argument using regex --- modules/shared.py | 7 ++++--- webui.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c..55de286d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,12 +81,13 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") -parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) -parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) diff --git a/webui.py b/webui.py index f4f1d74d..066d94f7 100644 --- a/webui.py +++ b/webui.py @@ -107,8 +107,12 @@ def initialize(): def setup_cors(app): - if cmd_opts.cors_allow_origins: + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + elif cmd_opts.cors_allow_origins: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) def create_api(app): -- cgit v1.2.3 From 9ed4a126bd6421f91bf4a9bdd348b6aef0a378c6 Mon Sep 17 00:00:00 2001 From: kavorite Date: Mon, 7 Nov 2022 19:58:49 -0500 Subject: add gradio-inpaint-tool; color-sketch --- modules/img2img.py | 19 +++++++++++++------ modules/shared.py | 1 + modules/ui.py | 11 ++++++++++- 3 files changed, 24 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/img2img.py b/modules/img2img.py index be9f3653..00c6f827 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,18 +59,25 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 if is_inpaint: # Drawn mask if mask_mode == 0: - image = init_img_with_mask['image'] - mask = init_img_with_mask['mask'] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - image = image.convert('RGB') + image = init_img_with_mask + is_mask_sketch = isinstance(image, dict) + if is_mask_sketch: + # Sketch: mask iff. not transparent + image, mask = image["image"], image["mask"] + mask = np.array(mask)[..., -1] > 0 + else: + # Color-sketch: mask iff. painted over + orig = init_img_with_mask_orig or image + mask = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(mask.astype(np.uint8) * 255, "L") + image = image.convert("RGB") # Uploaded mask else: image = init_img_inpaint diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..325e37d9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -71,6 +71,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") +parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 2609857e..db323e9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -840,8 +840,17 @@ def create_ui(wrap_gradio_gpu_call): init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") -- cgit v1.2.3 From cfcadeae9a61e1aff32960864f90299412c86d5c Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 8 Nov 2022 10:03:56 -0600 Subject: Add option to preload extensions By creating a file called "preload.py" in an extension folder and declaring a preload(parser) method, we can add extra command-line args for an extension. --- modules/extensions.py | 23 ++++++++++++++++++++++- modules/shared.py | 5 ++++- 2 files changed, 26 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/extensions.py b/modules/extensions.py index 8e0977fd..544f3580 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,12 +1,12 @@ import os import sys import traceback +from importlib.machinery import SourceFileLoader import git from modules import paths, shared - extensions = [] extensions_dir = os.path.join(paths.script_path, "extensions") @@ -84,3 +84,24 @@ def list_extensions(): extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extensions.append(extension) + + +def preload_extensions(parser): + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + for file in os.listdir(path): + if "preload.py" in file: + full_file = os.path.join(path, file) + print(f"Got preload file: {full_file}") + + try: + ext = SourceFileLoader("preload", full_file).load_module() + parser = ext.preload(parser) + except Exception as e: + print(f"Exception preloading script: {e}") + return parser \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c..222ad4fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -15,7 +15,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae +from modules import sd_samplers, sd_models, localization, sd_vae, extensions from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -91,7 +91,10 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +extensions.preload_extensions(parser) + cmd_opts = parser.parse_args() + restricted_opts = { "samples_filename_pattern", "directories_filename_pattern", -- cgit v1.2.3 From 6f8a807fe4eb41f6eb355c80fe96cd60b8e8a5a9 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 09:22:49 +0900 Subject: Update shared.py --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 89f4d5ee..82da5ce0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -321,7 +321,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."), + "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.3 From 0959907f87314cbee8a80036ec8ae24c65888f7f Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:31:14 +0900 Subject: adding tag dropout option --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 82da5ce0..f2ea3baa 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -322,6 +322,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), + "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.3 From a1a376331c9ecbbee77b86daeaba44587cc56557 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:56:06 +0300 Subject: make existing script loading and new preload code use same code for loading modules limit extension preload scripts to just one file named preload.py --- modules/extensions.py | 21 --------------------- modules/script_loading.py | 34 ++++++++++++++++++++++++++++++++++ modules/scripts.py | 46 +++++++++++++++++----------------------------- modules/shared.py | 5 ++--- 4 files changed, 53 insertions(+), 53 deletions(-) create mode 100644 modules/script_loading.py (limited to 'modules/shared.py') diff --git a/modules/extensions.py b/modules/extensions.py index 544f3580..94ce479a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,6 @@ import os import sys import traceback -from importlib.machinery import SourceFileLoader import git @@ -85,23 +84,3 @@ def list_extensions(): extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extensions.append(extension) - -def preload_extensions(parser): - if not os.path.isdir(extensions_dir): - return - - for dirname in sorted(os.listdir(extensions_dir)): - path = os.path.join(extensions_dir, dirname) - if not os.path.isdir(path): - continue - for file in os.listdir(path): - if "preload.py" in file: - full_file = os.path.join(path, file) - print(f"Got preload file: {full_file}") - - try: - ext = SourceFileLoader("preload", full_file).load_module() - parser = ext.preload(parser) - except Exception as e: - print(f"Exception preloading script: {e}") - return parser \ No newline at end of file diff --git a/modules/script_loading.py b/modules/script_loading.py new file mode 100644 index 00000000..f93f0951 --- /dev/null +++ b/modules/script_loading.py @@ -0,0 +1,34 @@ +import os +import sys +import traceback +from types import ModuleType + + +def load_module(path): + with open(path, "r", encoding="utf8") as file: + text = file.read() + + compiled = compile(text, path, 'exec') + module = ModuleType(os.path.basename(path)) + exec(compiled, module.__dict__) + + return module + + +def preload_extensions(extensions_dir, parser): + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + preload_script = os.path.join(extensions_dir, dirname, "preload.py") + if not os.path.isfile(preload_script): + continue + + try: + module = load_module(preload_script) + if hasattr(module, 'preload'): + module.preload(parser) + + except Exception: + print(f"Error running preload() for {preload_script}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/scripts.py b/modules/scripts.py index 22d8908b..986b1914 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,7 +6,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions +from modules import shared, paths, script_callbacks, extensions, script_loading AlwaysVisible = object() @@ -161,13 +161,7 @@ def load_scripts(): sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - with open(scriptfile.path, "r", encoding="utf8") as file: - text = file.read() - - from types import ModuleType - compiled = compile(text, scriptfile.path, 'exec') - module = ModuleType(scriptfile.filename) - exec(compiled, module.__dict__) + module = script_loading.load_module(scriptfile.path) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -328,27 +322,21 @@ class ScriptRunner: def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): - with open(script.filename, "r", encoding="utf8") as file: - args_from = script.args_from - args_to = script.args_to - filename = script.filename - text = file.read() - - from types import ModuleType - - module = cache.get(filename, None) - if module is None: - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) - cache[filename] = module - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - self.scripts[si] = script_class() - self.scripts[si].filename = filename - self.scripts[si].args_from = args_from - self.scripts[si].args_to = args_to + args_from = script.args_from + args_to = script.args_to + filename = script.filename + + module = cache.get(filename, None) + if module is None: + module = script_loading.load_module(script.filename) + cache[filename] = module + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() diff --git a/modules/shared.py b/modules/shared.py index 17132e42..6936cbe0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,7 +3,6 @@ import datetime import json import os import sys -from collections import OrderedDict import time import gradio as gr @@ -15,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae, extensions +from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) -extensions.preload_extensions(parser) +script_loading.preload_extensions(extensions.extensions_dir, parser) cmd_opts = parser.parse_args() -- cgit v1.2.3 From d20dbe47e06de7f6c0e65242a04c9bb1410ef7cb Mon Sep 17 00:00:00 2001 From: Xu Cuijie <975114697@qq.com> Date: Sun, 13 Nov 2022 10:31:03 +0800 Subject: fix the model name error of Real-ESRGAN in the opts default value --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..c46c29f7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -299,7 +299,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), -- cgit v1.2.3 From 3405acc6a4dcef2b73782a04924a9a12422e54f0 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Mon, 14 Nov 2022 14:07:13 -0600 Subject: Give --server-name priority over --listen and add check for --server-name in addition to --share and --listen --- modules/shared.py | 2 +- webui.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..c628b580 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -106,7 +106,7 @@ restricted_opts = { "outdir_save", } -cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) diff --git a/webui.py b/webui.py index f4f1d74d..fc776669 100644 --- a/webui.py +++ b/webui.py @@ -33,7 +33,10 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork queue_lock = threading.Lock() -server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None def wrap_queued_call(func): def f(*args, **kwargs): -- cgit v1.2.3 From 8f2ff861d31972d12de278075ea9c0c0deef99de Mon Sep 17 00:00:00 2001 From: Maiko Sinkyaet Tan Date: Tue, 15 Nov 2022 16:12:34 +0800 Subject: feat: add http basic authentication for api --- modules/api/api.py | 61 ++++++++++++++++++++++++++++++++++++------------------ modules/shared.py | 1 + 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..6bb01603 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,6 +5,9 @@ import uvicorn from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from secrets import compare_digest + import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -57,29 +60,47 @@ def encode_pil_to_base64(image): class Api: def __init__(self, app: FastAPI, queue_lock: Lock): + if shared.cmd_opts.api_auth: + self.credenticals = dict() + for auth in shared.cmd_opts.api_auth.split(","): + user, password = auth.split(":") + self.credenticals[user] = password + self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) - self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) - self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) - self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) - self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) - self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) - self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) - self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) - self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) - self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) - self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) + self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) + self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) + self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + + def add_api_route(self, path: str, endpoint, **kwargs): + if shared.cmd_opts.api_auth: + return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) + return self.app.add_api_route(path, endpoint, **kwargs) + + def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): + if credenticals.username in self.credenticals: + if compare_digest(credenticals.password, self.credenticals[credenticals.username]): + return True + + raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..62d526fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,6 +81,7 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -- cgit v1.2.3 From 0663706d4405b4f76ce653097f4f8989ee8b8684 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 13:47:03 +0700 Subject: Option to use selected VAE as default fallback instead of primary option --- modules/sd_vae.py | 25 ++++++++++++++++--------- modules/shared.py | 1 + webui.py | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..0b5f0213 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -83,7 +83,19 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): return vae_list -def resolve_vae(checkpoint_file, vae_file="auto"): +def get_vae_from_settings(vae_file="auto"): + # else, we load from settings, if not set to be default + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + return vae_file + + +def resolve_vae(checkpoint_file=None, vae_file="auto"): global first_load, vae_dict, vae_list # if vae_file argument is provided, it takes priority, but not saved @@ -98,14 +110,9 @@ def resolve_vae(checkpoint_file, vae_file="auto"): shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") - # else, we load from settings - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print("Selected VAE doesn't exist") + # fallback to selector in settings, if vae selector not set to act as default fallback + if not shared.opts.sd_vae_as_default: + vae_file = get_vae_from_settings(vae_file) # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): diff --git a/modules/shared.py b/modules/shared.py index 17132e42..b84767f0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -336,6 +336,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/webui.py b/webui.py index f4f1d74d..2cd3bae9 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,7 @@ def initialize(): modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From c8f7b5cdd73969d3d5027ceb71cbbd83d557702b Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 17132e42..a9daf800 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3 From d9fd4525a5d684100997130cc4132736bab1e4d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 11:09:44 +0300 Subject: change text for sd_vae_as_default that makes more sense to me --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 5528ab15..1c42641d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), - "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), + "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3 From 5a6387e189dc365c47a7979b9040d5b6fdd7ba43 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 15:15:24 +0300 Subject: make it possible to change models etc by editing options using API --- modules/api/api.py | 7 +++---- modules/shared.py | 17 +++++++++++++++++ modules/ui.py | 22 ++++------------------ 3 files changed, 24 insertions(+), 22 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index bf700ed0..1e324d8d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -253,9 +253,8 @@ class Api: return options def set_config(self, req: Dict[str, Any]): - - for o in req: - setattr(shared.opts, o, req[o]) + for k, v in req.items(): + shared.opts.set(k, v) shared.opts.save(shared.config_filename) return @@ -264,7 +263,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] + return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/shared.py b/modules/shared.py index 84567c8e..58f53e54 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -437,6 +437,23 @@ class Options: return super(Options, self).__getattribute__(item) + def set(self, key, value): + """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" + + oldval = self.data.get(key, None) + if oldval == value: + return False + + try: + setattr(self, key, value) + except RuntimeError: + return False + + if self.data_labels[key].onchange is not None: + self.data_labels[key].onchange() + + return True + def save(self, filename): assert not cmd_opts.freeze_settings, "saving settings is disabled" diff --git a/modules/ui.py b/modules/ui.py index 5dd97754..bb090c62 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1484,16 +1484,9 @@ def create_ui(wrap_gradio_gpu_call): if comp == dummy_component: continue - oldval = opts.data.get(key, None) - try: - setattr(opts, key, value) - except RuntimeError: - continue - if oldval != value: - if opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - + if opts.set(key, value): changed.append(key) + try: opts.save(shared.config_filename) except RuntimeError: @@ -1504,15 +1497,8 @@ def create_ui(wrap_gradio_gpu_call): if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() - oldval = opts.data.get(key, None) - try: - setattr(opts, key, value) - except Exception: - return gr.update(value=oldval), opts.dumpjson() - - if oldval != value: - if opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() opts.save(shared.config_filename) -- cgit v1.2.3 From bd68e35de3b7cf7547ed97d8bdf60147402133cc Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:35:26 +0900 Subject: Gradient accumulation, autocast fix, new latent sampling method, etc --- modules/hypernetworks/hypernetwork.py | 269 +++++++++++---------- modules/sd_hijack.py | 9 +- modules/sd_hijack_checkpoint.py | 10 + modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 134 +++++++---- modules/textual_inversion/textual_inversion.py | 320 ++++++++++++++----------- modules/ui.py | 16 +- 7 files changed, 448 insertions(+), 313 deletions(-) create mode 100644 modules/sd_hijack_checkpoint.py (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fbb87dd1..3d3301b0 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -367,13 +367,13 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -403,28 +403,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() - ititial_step = hypernetwork.step or 0 - if ititial_step >= steps: + initial_step = hypernetwork.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return hypernetwork, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - - size = len(ds.indexes) - loss_dict = defaultdict(lambda : deque(maxlen = 1024)) - losses = torch.zeros((size,)) - previous_mean_losses = [0] - previous_mean_loss = 0 - print("Mean loss of {} elements".format(size)) weights = hypernetwork.weights() for weight in weights: @@ -436,8 +432,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) - optimizer_name = 'AdamW' + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. try: @@ -446,131 +442,155 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log print("Cannot resume from saved optimizer!") print(e) + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + # size = len(ds.indexes) + # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + # losses = torch.zeros((size,)) + # previous_mean_losses = [0] + # previous_mean_loss = 0 + # print("Mean loss of {} elements".format(size)) + steps_without_grad = 0 last_saved_file = "" last_saved_image = "" forced_filename = "" - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if tag_drop_out != 0 or shuffle_tags: + shared.sd_model.cond_stage_model.to(devices.device) + c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory) + shared.sd_model.cond_stage_model.to(devices.cpu) + else: + c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + del c + + _loss_step += loss.item() + scaler.scale(loss).backward() + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") + # scaler.unscale_(optimizer) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + scaler.step(optimizer) + scaler.update() + hypernetwork.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = hypernetwork.step + 1 - optimizer.zero_grad() - weights[0].grad = None - loss.backward() - - if weights[0].grad is None: - steps_without_grad += 1 - else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - - optimizer.step() - - steps_done = hypernetwork.step + 1 - - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") - - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) - else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) - - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') - hypernetwork.optimizer_name = optimizer_name - if shared.opts.save_optimizer_state: - hypernetwork.optimizer_state_dict = optimizer.state_dict() - save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) - hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. - - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + epoch_num = hypernetwork.step // steps_per_epoch + epoch_step = hypernetwork.step % steps_per_epoch + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = hypernetwork.step + shared.state.job_no = hypernetwork.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

-Loss: {previous_mean_loss:.7f}
+Loss: {loss_step:.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - report_statistics(loss_dict) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + finally: + pbar.leave = False + pbar.close() + #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name @@ -579,6 +599,9 @@ Last saved image: {html.escape(last_saved_image)}
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) del optimizer hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + return hypernetwork, filename def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..29c8b561 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -59,6 +59,10 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 +def fix_checkpoint(): + ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward class StableDiffusionModelHijack: fixes = None @@ -78,6 +82,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model apply_optimizations() + fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py new file mode 100644 index 00000000..5712972f --- /dev/null +++ b/modules/sd_hijack_checkpoint.py @@ -0,0 +1,10 @@ +from torch.utils.checkpoint import checkpoint + +def BasicTransformerBlock_forward(self, x, context=None): + return checkpoint(self._forward, x, context) + +def AttentionBlock_forward(self, x): + return checkpoint(self._forward, x) + +def ResBlock_forward(self, x, emb): + return checkpoint(self._forward, x, emb) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index a4457305..3704ce23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), - "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index eb75c376..d594b49d 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,7 +3,7 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from torchvision import transforms import random @@ -11,25 +11,28 @@ import tqdm from modules import devices, shared import re +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, latent=None, filename_text=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename - self.latent = latent self.filename_text = filename_text - self.cond = None - self.cond_text = None + self.latent_dist = latent_dist + self.latent_sample = latent_sample + self.cond = cond + self.cond_text = cond_text + self.pixel_values = pixel_values class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None - + self.placeholder_token = placeholder_token - self.batch_size = batch_size self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) @@ -45,11 +48,16 @@ class PersonalizedBase(Dataset): assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - cond_model = shared.sd_model.cond_stage_model - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + + + self.shuffle_tags = shuffle_tags + self.tag_drop_out = tag_drop_out + print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): + if shared.state.interrupted: + raise Exception("inturrupted") try: image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: @@ -71,37 +79,58 @@ class PersonalizedBase(Dataset): npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) - torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) - torchdata = torch.moveaxis(torchdata, 2, 0) - - init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() - init_latent = init_latent.to(devices.cpu) - - entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) - - if include_cond: + torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) + latent_sample = None + + with torch.autocast("cuda"): + latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) + + if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + latent_sampling_method = "once" + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "deterministic": + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + + if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) - entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - self.dataset.append(entry) - - assert len(self.dataset) > 0, "No images have been found in the dataset." - self.length = len(self.dataset) * repeats // batch_size + if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): + with torch.autocast("cuda"): + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # elif not include_cond: + # _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text]) + # max_n = token_count // 75 + # index_list = [ [] for _ in range(max_n + 1) ] + # for n, (z, _) in hijack_fixes[0]: + # index_list[n].append(z) + # with torch.autocast("cuda"): + # entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # entry.emb_index = index_list - self.dataset_length = len(self.dataset) - self.indexes = None - self.shuffle() + self.dataset.append(entry) + del torchdata + del latent_dist + del latent_sample - def shuffle(self): - self.indexes = np.random.permutation(self.dataset_length) + self.length = len(self.dataset) + assert self.length > 0, "No images have been found in the dataset." + self.batch_size = min(batch_size, self.length) + self.gradient_step = min(gradient_step, self.length // self.batch_size) + self.latent_sampling_method = latent_sampling_method def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') - if shared.opts.tag_drop_out != 0: - tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] - if shared.opts.shuffle_tags: + if self.tag_drop_out != 0: + tags = [t for t in tags if random.random() > self.tag_drop_out] + if self.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) return text @@ -110,19 +139,28 @@ class PersonalizedBase(Dataset): return self.length def __getitem__(self, i): - res = [] - - for j in range(self.batch_size): - position = i * self.batch_size + j - if position % len(self.indexes) == 0: - self.shuffle() - - index = self.indexes[position % len(self.indexes)] - entry = self.dataset[index] - - if entry.cond is None: - entry.cond_text = self.create_text(entry.filename_text) - - res.append(entry) - - return res + entry = self.dataset[i] + if self.tag_drop_out != 0 or self.shuffle_tags: + entry.cond_text = self.create_text(entry.filename_text) + if self.latent_sampling_method == "random": + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist) + return entry + +class PersonalizedDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs) + self.collate_fn = collate_wrapper + + +class BatchLoader: + def __init__(self, data): + self.cond_text = [entry.cond_text for entry in data] + self.cond = [entry.cond for entry in data] + self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + + def pin_memory(self): + self.latent_sample = self.latent_sample.pin_memory() + return self + +def collate_wrapper(batch): + return BatchLoader(batch) \ No newline at end of file diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5e4d8688..1d5e3a32 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -184,7 +184,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return - if (step + 1) % shared.opts.training_write_csv_every != 0: + if step % shared.opts.training_write_csv_every != 0: return write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True @@ -194,21 +194,23 @@ def write_loss(log_directory, filename, step, epoch_len, values): if write_csv_header: csv_writer.writeheader() - epoch = step // epoch_len - epoch_step = step % epoch_len + epoch = (step - 1) // epoch_len + epoch_step = (step - 1) % epoch_len csv_writer.writerow({ - "step": step + 1, + "step": step, "epoch": epoch, - "epoch_step": epoch_step + 1, + "epoch_step": epoch_step, **values, }) -def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" @@ -224,10 +226,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -255,161 +257,205 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc else: images_embeds_dir = None - cond_model = shared.sd_model.cond_stage_model - hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] checkpoint = sd_models.select_checkpoint() - ititial_step = embedding.step or 0 - if ititial_step >= steps: + initial_step = embedding.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - - # dataset loading may take a while, so input validations and early returns should be done before this + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False) + if unload: shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + scaler = torch.cuda.amp.GradScaler() - losses = torch.zeros((32,)) + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + last_saved_file = "" last_saved_image = "" forced_filename = "" embedding_yet_to_be_embedded = False - - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step - - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - - losses[embedding.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - steps_done = embedding.step + 1 - - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) - - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding_name_every = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') - save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) - embedding_yet_to_be_embedded = True - - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height - - preview_text = p.prompt - - processed = processing.process_images(p) - image = processed.images[0] - - if unload: - shared.sd_model.first_stage_model.to(devices.cpu) - - shared.state.current_image = image - - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) - - title = "<{}>".format(data.get('name', '???')) - - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' - - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) - - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) - - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False - - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = embedding.step - - shared.state.textinfo = f""" + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + # c = stack_conds(batch.cond).to(devices.device) + # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) + # print(mask) + # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #scaler.unscale_(optimizer) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + #if shared.opts.save_optimizer_state: + #embedding.optimizer_state_dict = optimizer.state_dict() + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None + + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) + + title = "<{}>".format(data.get('name', '???')) + + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' + + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) + + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False + + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f"""

-Loss: {losses.mean():.7f}
+Loss: {loss_step:.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) - shared.sd_model.first_stage_model.to(devices.device) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + pass + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) return embedding, filename diff --git a/modules/ui.py b/modules/ui.py index a5953fce..9d2a1cbf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): with gr.Row(): interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') + run_preprocess = gr.Button(value="Preprocess", variant='primary') process_split.change( fn=lambda show: gr_show(show), @@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call): hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") batch_size = gr.Number(label='Batch size', value=1, precision=0) + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) @@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + with gr.Row(): + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + with gr.Row(): + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call): train_embedding_name, embedding_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file, @@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork_name, hypernetwork_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From c81d440d876dfd2ab3560410f37442ef56fc6632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 20 Nov 2022 16:39:20 +0300 Subject: moved deepdanbooru to pure pytorch implementation --- README.md | 2 +- launch.py | 5 - modules/api/api.py | 10 +- modules/deepbooru.py | 258 +++++------- modules/deepbooru_model.py | 676 ++++++++++++++++++++++++++++++++ modules/shared.py | 2 +- modules/textual_inversion/preprocess.py | 12 +- modules/ui.py | 7 +- 8 files changed, 777 insertions(+), 195 deletions(-) create mode 100644 modules/deepbooru_model.py (limited to 'modules/shared.py') diff --git a/README.md b/README.md index 33508f31..5f5ab3aa 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - separate prompts using uppercase `AND` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) -- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) +- DeepDanbooru integration, creates danbooru style tags for anime prompts - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI - Generate forever option diff --git a/launch.py b/launch.py index 0f84b5d1..d2f1055c 100644 --- a/launch.py +++ b/launch.py @@ -134,7 +134,6 @@ def prepare_enviroment(): gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") - deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff") xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') @@ -158,7 +157,6 @@ def prepare_enviroment(): sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, run_tests = extract_arg(sys.argv, '--tests') xformers = '--xformers' in sys.argv - deepdanbooru = '--deepdanbooru' in sys.argv ngrok = '--ngrok' in sys.argv try: @@ -193,9 +191,6 @@ def prepare_enviroment(): elif platform.system() == "Linux": run_pip("install xformers", "xformers") - if not is_installed("deepdanbooru") and deepdanbooru: - run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") - if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") diff --git a/modules/api/api.py b/modules/api/api.py index 79b2c818..7a567be3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,7 +9,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers +from modules import sd_samplers, deepbooru from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo @@ -18,9 +18,6 @@ from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List -if shared.cmd_opts.deepdanbooru: - from modules.deepbooru import get_deepbooru_tags - def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) @@ -245,10 +242,7 @@ class Api: if interrogatereq.model == "clip": processed = shared.interrogator.interrogate(img) elif interrogatereq.model == "deepdanbooru": - if shared.cmd_opts.deepdanbooru: - processed = get_deepbooru_tags(img) - else: - raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.") + processed = deepbooru.model.tag(img) else: raise HTTPException(status_code=404, detail="Model not found") diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 8bbc90a4..b9066d81 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,173 +1,97 @@ -import os.path -from concurrent.futures import ProcessPoolExecutor -import multiprocessing -import time +import os import re +import torch +from PIL import Image +import numpy as np + +from modules import modelloader, paths, deepbooru_model, devices, images, shared + re_special = re.compile(r'([\\()])') -def get_deepbooru_tags(pil_image): - """ - This method is for running only one image at a time for simple use. Used to the img2img interrogate. - """ - from modules import shared # prevents circular reference - - try: - create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) - return get_tags_from_process(pil_image) - finally: - release_process() - - -OPT_INCLUDE_RANKS = "include_ranks" -def create_deepbooru_opts(): - from modules import shared - - return { - "use_spaces": shared.opts.deepbooru_use_spaces, - "use_escape": shared.opts.deepbooru_escape, - "alpha_sort": shared.opts.deepbooru_sort_alpha, - OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks, - } - - -def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): - model, tags = get_deepbooru_tags_model() - while True: # while process is running, keep monitoring queue for new image - pil_image = queue.get() - if pil_image == "QUIT": - break - else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) - - -def create_deepbooru_process(threshold, deepbooru_opts): - """ - Creates deepbooru process. A queue is created to send images into the process. This enables multiple images - to be processed in a row without reloading the model or creating a new process. To return the data, a shared - dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned - to the dictionary and the method adding the image to the queue should wait for this value to be updated with - the tags. - """ - from modules import shared # prevents circular reference - context = multiprocessing.get_context("spawn") - shared.deepbooru_process_manager = context.Manager() - shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() - shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) - shared.deepbooru_process.start() - - -def get_tags_from_process(image): - from modules import shared - - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process_queue.put(image) - while shared.deepbooru_process_return["value"] == -1: - time.sleep(0.2) - caption = shared.deepbooru_process_return["value"] - shared.deepbooru_process_return["value"] = -1 - - return caption - - -def release_process(): - """ - Stops the deepbooru process to return used memory - """ - from modules import shared # prevents circular reference - shared.deepbooru_process_queue.put("QUIT") - shared.deepbooru_process.join() - shared.deepbooru_process_queue = None - shared.deepbooru_process = None - shared.deepbooru_process_return = None - shared.deepbooru_process_manager = None - -def get_deepbooru_tags_model(): - import deepdanbooru as dd - import tensorflow as tf - import numpy as np - this_folder = os.path.dirname(__file__) - model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) - if not os.path.exists(os.path.join(model_path, 'project.json')): - # there is no point importing these every time - import zipfile - from basicsr.utils.download_util import load_file_from_url - load_file_from_url( - r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", - model_path) - with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: - zip_ref.extractall(model_path) - os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) - - tags = dd.project.load_tags_from_project(model_path) - model = dd.project.load_model_from_project( - model_path, compile_model=False - ) - return model, tags - - -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): - import deepdanbooru as dd - import tensorflow as tf - import numpy as np - - alpha_sort = deepbooru_opts['alpha_sort'] - use_spaces = deepbooru_opts['use_spaces'] - use_escape = deepbooru_opts['use_escape'] - include_ranks = deepbooru_opts['include_ranks'] - - width = model.input_shape[2] - height = model.input_shape[1] - image = np.array(pil_image) - image = tf.image.resize( - image, - size=(height, width), - method=tf.image.ResizeMethod.AREA, - preserve_aspect_ratio=True, - ) - image = image.numpy() # EagerTensor to np.array - image = dd.image.transform_and_pad_image(image, width, height) - image = image / 255.0 - image_shape = image.shape - image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) - - y = model.predict(image)[0] - - result_dict = {} - - for i, tag in enumerate(tags): - result_dict[tag] = y[i] - - unsorted_tags_in_theshold = [] - result_tags_print = [] - for tag in tags: - if result_dict[tag] >= threshold: + +class DeepDanbooru: + def __init__(self): + self.model = None + + def load(self): + if self.model is not None: + return + + files = modelloader.load_models( + model_path=os.path.join(paths.models_path, "torch_deepdanbooru"), + model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', + ext_filter=".pt", + download_name='model-resnet_custom_v3.pt', + ) + + self.model = deepbooru_model.DeepDanbooruModel() + self.model.load_state_dict(torch.load(files[0], map_location="cpu")) + + self.model.eval() + self.model.to(devices.cpu, devices.dtype) + + def start(self): + self.load() + self.model.to(devices.device) + + def stop(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model.to(devices.cpu) + devices.torch_gc() + + def tag(self, pil_image): + self.start() + res = self.tag_multi(pil_image) + self.stop() + + return res + + def tag_multi(self, pil_image, force_disable_ranks=False): + threshold = shared.opts.interrogate_deepbooru_score_threshold + use_spaces = shared.opts.deepbooru_use_spaces + use_escape = shared.opts.deepbooru_escape + alpha_sort = shared.opts.deepbooru_sort_alpha + include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks + + pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512) + a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 + + with torch.no_grad(), devices.autocast(): + x = torch.from_numpy(a).cuda() + y = self.model(x)[0].detach().cpu().numpy() + + probability_dict = {} + + for tag, probability in zip(self.model.tags, y): + if probability < threshold: + continue + if tag.startswith("rating:"): continue - unsorted_tags_in_theshold.append((result_dict[tag], tag)) - result_tags_print.append(f'{result_dict[tag]} {tag}') - - # sort tags - result_tags_out = [] - sort_ndx = 0 - if alpha_sort: - sort_ndx = 1 - - # sort by reverse by likelihood and normal for alpha, and format tag text as requested - unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) - for weight, tag in unsorted_tags_in_theshold: - tag_outformat = tag - if use_spaces: - tag_outformat = tag_outformat.replace('_', ' ') - if use_escape: - tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) - if include_ranks: - tag_outformat = f"({tag_outformat}:{weight:.3f})" - - result_tags_out.append(tag_outformat) - - print('\n'.join(sorted(result_tags_print, reverse=True))) - - return ', '.join(result_tags_out) + + probability_dict[tag] = probability + + if alpha_sort: + tags = sorted(probability_dict) + else: + tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] + + res = [] + + for tag in tags: + probability = probability_dict[tag] + tag_outformat = tag + if use_spaces: + tag_outformat = tag_outformat.replace('_', ' ') + if use_escape: + tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) + if include_ranks: + tag_outformat = f"({tag_outformat}:{probability:.3f})" + + res.append(tag_outformat) + + return ", ".join(res) + + +model = DeepDanbooru() diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py new file mode 100644 index 00000000..edd40c81 --- /dev/null +++ b/modules/deepbooru_model.py @@ -0,0 +1,676 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more + + +class DeepDanbooruModel(nn.Module): + def __init__(self): + super(DeepDanbooruModel, self).__init__() + + self.tags = [] + + self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2)) + self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)) + self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) + self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64) + self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) + self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) + self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64) + self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) + self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) + self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64) + self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) + self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) + self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2)) + self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128) + self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2)) + self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) + self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) + self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) + self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2)) + self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256) + self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2)) + self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2)) + self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2)) + self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) + self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) + self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) + self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2)) + self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512) + self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2)) + self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) + self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512) + self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512) + self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) + self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512) + self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512) + self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) + self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2)) + self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024) + self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2)) + self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) + self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024) + self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024) + self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) + self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024) + self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024) + self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) + self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False) + + def forward(self, *inputs): + t_358, = inputs + t_359 = t_358.permute(*[0, 3, 1, 2]) + t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) + t_360 = self.n_Conv_0(t_359_padded) + t_361 = F.relu(t_360) + t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) + t_362 = self.n_MaxPool_0(t_361) + t_363 = self.n_Conv_1(t_362) + t_364 = self.n_Conv_2(t_362) + t_365 = F.relu(t_364) + t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0) + t_366 = self.n_Conv_3(t_365_padded) + t_367 = F.relu(t_366) + t_368 = self.n_Conv_4(t_367) + t_369 = torch.add(t_368, t_363) + t_370 = F.relu(t_369) + t_371 = self.n_Conv_5(t_370) + t_372 = F.relu(t_371) + t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0) + t_373 = self.n_Conv_6(t_372_padded) + t_374 = F.relu(t_373) + t_375 = self.n_Conv_7(t_374) + t_376 = torch.add(t_375, t_370) + t_377 = F.relu(t_376) + t_378 = self.n_Conv_8(t_377) + t_379 = F.relu(t_378) + t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0) + t_380 = self.n_Conv_9(t_379_padded) + t_381 = F.relu(t_380) + t_382 = self.n_Conv_10(t_381) + t_383 = torch.add(t_382, t_377) + t_384 = F.relu(t_383) + t_385 = self.n_Conv_11(t_384) + t_386 = self.n_Conv_12(t_384) + t_387 = F.relu(t_386) + t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0) + t_388 = self.n_Conv_13(t_387_padded) + t_389 = F.relu(t_388) + t_390 = self.n_Conv_14(t_389) + t_391 = torch.add(t_390, t_385) + t_392 = F.relu(t_391) + t_393 = self.n_Conv_15(t_392) + t_394 = F.relu(t_393) + t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0) + t_395 = self.n_Conv_16(t_394_padded) + t_396 = F.relu(t_395) + t_397 = self.n_Conv_17(t_396) + t_398 = torch.add(t_397, t_392) + t_399 = F.relu(t_398) + t_400 = self.n_Conv_18(t_399) + t_401 = F.relu(t_400) + t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0) + t_402 = self.n_Conv_19(t_401_padded) + t_403 = F.relu(t_402) + t_404 = self.n_Conv_20(t_403) + t_405 = torch.add(t_404, t_399) + t_406 = F.relu(t_405) + t_407 = self.n_Conv_21(t_406) + t_408 = F.relu(t_407) + t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0) + t_409 = self.n_Conv_22(t_408_padded) + t_410 = F.relu(t_409) + t_411 = self.n_Conv_23(t_410) + t_412 = torch.add(t_411, t_406) + t_413 = F.relu(t_412) + t_414 = self.n_Conv_24(t_413) + t_415 = F.relu(t_414) + t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0) + t_416 = self.n_Conv_25(t_415_padded) + t_417 = F.relu(t_416) + t_418 = self.n_Conv_26(t_417) + t_419 = torch.add(t_418, t_413) + t_420 = F.relu(t_419) + t_421 = self.n_Conv_27(t_420) + t_422 = F.relu(t_421) + t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0) + t_423 = self.n_Conv_28(t_422_padded) + t_424 = F.relu(t_423) + t_425 = self.n_Conv_29(t_424) + t_426 = torch.add(t_425, t_420) + t_427 = F.relu(t_426) + t_428 = self.n_Conv_30(t_427) + t_429 = F.relu(t_428) + t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0) + t_430 = self.n_Conv_31(t_429_padded) + t_431 = F.relu(t_430) + t_432 = self.n_Conv_32(t_431) + t_433 = torch.add(t_432, t_427) + t_434 = F.relu(t_433) + t_435 = self.n_Conv_33(t_434) + t_436 = F.relu(t_435) + t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0) + t_437 = self.n_Conv_34(t_436_padded) + t_438 = F.relu(t_437) + t_439 = self.n_Conv_35(t_438) + t_440 = torch.add(t_439, t_434) + t_441 = F.relu(t_440) + t_442 = self.n_Conv_36(t_441) + t_443 = self.n_Conv_37(t_441) + t_444 = F.relu(t_443) + t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0) + t_445 = self.n_Conv_38(t_444_padded) + t_446 = F.relu(t_445) + t_447 = self.n_Conv_39(t_446) + t_448 = torch.add(t_447, t_442) + t_449 = F.relu(t_448) + t_450 = self.n_Conv_40(t_449) + t_451 = F.relu(t_450) + t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0) + t_452 = self.n_Conv_41(t_451_padded) + t_453 = F.relu(t_452) + t_454 = self.n_Conv_42(t_453) + t_455 = torch.add(t_454, t_449) + t_456 = F.relu(t_455) + t_457 = self.n_Conv_43(t_456) + t_458 = F.relu(t_457) + t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0) + t_459 = self.n_Conv_44(t_458_padded) + t_460 = F.relu(t_459) + t_461 = self.n_Conv_45(t_460) + t_462 = torch.add(t_461, t_456) + t_463 = F.relu(t_462) + t_464 = self.n_Conv_46(t_463) + t_465 = F.relu(t_464) + t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0) + t_466 = self.n_Conv_47(t_465_padded) + t_467 = F.relu(t_466) + t_468 = self.n_Conv_48(t_467) + t_469 = torch.add(t_468, t_463) + t_470 = F.relu(t_469) + t_471 = self.n_Conv_49(t_470) + t_472 = F.relu(t_471) + t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0) + t_473 = self.n_Conv_50(t_472_padded) + t_474 = F.relu(t_473) + t_475 = self.n_Conv_51(t_474) + t_476 = torch.add(t_475, t_470) + t_477 = F.relu(t_476) + t_478 = self.n_Conv_52(t_477) + t_479 = F.relu(t_478) + t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0) + t_480 = self.n_Conv_53(t_479_padded) + t_481 = F.relu(t_480) + t_482 = self.n_Conv_54(t_481) + t_483 = torch.add(t_482, t_477) + t_484 = F.relu(t_483) + t_485 = self.n_Conv_55(t_484) + t_486 = F.relu(t_485) + t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0) + t_487 = self.n_Conv_56(t_486_padded) + t_488 = F.relu(t_487) + t_489 = self.n_Conv_57(t_488) + t_490 = torch.add(t_489, t_484) + t_491 = F.relu(t_490) + t_492 = self.n_Conv_58(t_491) + t_493 = F.relu(t_492) + t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0) + t_494 = self.n_Conv_59(t_493_padded) + t_495 = F.relu(t_494) + t_496 = self.n_Conv_60(t_495) + t_497 = torch.add(t_496, t_491) + t_498 = F.relu(t_497) + t_499 = self.n_Conv_61(t_498) + t_500 = F.relu(t_499) + t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0) + t_501 = self.n_Conv_62(t_500_padded) + t_502 = F.relu(t_501) + t_503 = self.n_Conv_63(t_502) + t_504 = torch.add(t_503, t_498) + t_505 = F.relu(t_504) + t_506 = self.n_Conv_64(t_505) + t_507 = F.relu(t_506) + t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0) + t_508 = self.n_Conv_65(t_507_padded) + t_509 = F.relu(t_508) + t_510 = self.n_Conv_66(t_509) + t_511 = torch.add(t_510, t_505) + t_512 = F.relu(t_511) + t_513 = self.n_Conv_67(t_512) + t_514 = F.relu(t_513) + t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0) + t_515 = self.n_Conv_68(t_514_padded) + t_516 = F.relu(t_515) + t_517 = self.n_Conv_69(t_516) + t_518 = torch.add(t_517, t_512) + t_519 = F.relu(t_518) + t_520 = self.n_Conv_70(t_519) + t_521 = F.relu(t_520) + t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0) + t_522 = self.n_Conv_71(t_521_padded) + t_523 = F.relu(t_522) + t_524 = self.n_Conv_72(t_523) + t_525 = torch.add(t_524, t_519) + t_526 = F.relu(t_525) + t_527 = self.n_Conv_73(t_526) + t_528 = F.relu(t_527) + t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0) + t_529 = self.n_Conv_74(t_528_padded) + t_530 = F.relu(t_529) + t_531 = self.n_Conv_75(t_530) + t_532 = torch.add(t_531, t_526) + t_533 = F.relu(t_532) + t_534 = self.n_Conv_76(t_533) + t_535 = F.relu(t_534) + t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0) + t_536 = self.n_Conv_77(t_535_padded) + t_537 = F.relu(t_536) + t_538 = self.n_Conv_78(t_537) + t_539 = torch.add(t_538, t_533) + t_540 = F.relu(t_539) + t_541 = self.n_Conv_79(t_540) + t_542 = F.relu(t_541) + t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0) + t_543 = self.n_Conv_80(t_542_padded) + t_544 = F.relu(t_543) + t_545 = self.n_Conv_81(t_544) + t_546 = torch.add(t_545, t_540) + t_547 = F.relu(t_546) + t_548 = self.n_Conv_82(t_547) + t_549 = F.relu(t_548) + t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0) + t_550 = self.n_Conv_83(t_549_padded) + t_551 = F.relu(t_550) + t_552 = self.n_Conv_84(t_551) + t_553 = torch.add(t_552, t_547) + t_554 = F.relu(t_553) + t_555 = self.n_Conv_85(t_554) + t_556 = F.relu(t_555) + t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0) + t_557 = self.n_Conv_86(t_556_padded) + t_558 = F.relu(t_557) + t_559 = self.n_Conv_87(t_558) + t_560 = torch.add(t_559, t_554) + t_561 = F.relu(t_560) + t_562 = self.n_Conv_88(t_561) + t_563 = F.relu(t_562) + t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0) + t_564 = self.n_Conv_89(t_563_padded) + t_565 = F.relu(t_564) + t_566 = self.n_Conv_90(t_565) + t_567 = torch.add(t_566, t_561) + t_568 = F.relu(t_567) + t_569 = self.n_Conv_91(t_568) + t_570 = F.relu(t_569) + t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0) + t_571 = self.n_Conv_92(t_570_padded) + t_572 = F.relu(t_571) + t_573 = self.n_Conv_93(t_572) + t_574 = torch.add(t_573, t_568) + t_575 = F.relu(t_574) + t_576 = self.n_Conv_94(t_575) + t_577 = F.relu(t_576) + t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0) + t_578 = self.n_Conv_95(t_577_padded) + t_579 = F.relu(t_578) + t_580 = self.n_Conv_96(t_579) + t_581 = torch.add(t_580, t_575) + t_582 = F.relu(t_581) + t_583 = self.n_Conv_97(t_582) + t_584 = F.relu(t_583) + t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0) + t_585 = self.n_Conv_98(t_584_padded) + t_586 = F.relu(t_585) + t_587 = self.n_Conv_99(t_586) + t_588 = self.n_Conv_100(t_582) + t_589 = torch.add(t_587, t_588) + t_590 = F.relu(t_589) + t_591 = self.n_Conv_101(t_590) + t_592 = F.relu(t_591) + t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0) + t_593 = self.n_Conv_102(t_592_padded) + t_594 = F.relu(t_593) + t_595 = self.n_Conv_103(t_594) + t_596 = torch.add(t_595, t_590) + t_597 = F.relu(t_596) + t_598 = self.n_Conv_104(t_597) + t_599 = F.relu(t_598) + t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0) + t_600 = self.n_Conv_105(t_599_padded) + t_601 = F.relu(t_600) + t_602 = self.n_Conv_106(t_601) + t_603 = torch.add(t_602, t_597) + t_604 = F.relu(t_603) + t_605 = self.n_Conv_107(t_604) + t_606 = F.relu(t_605) + t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0) + t_607 = self.n_Conv_108(t_606_padded) + t_608 = F.relu(t_607) + t_609 = self.n_Conv_109(t_608) + t_610 = torch.add(t_609, t_604) + t_611 = F.relu(t_610) + t_612 = self.n_Conv_110(t_611) + t_613 = F.relu(t_612) + t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0) + t_614 = self.n_Conv_111(t_613_padded) + t_615 = F.relu(t_614) + t_616 = self.n_Conv_112(t_615) + t_617 = torch.add(t_616, t_611) + t_618 = F.relu(t_617) + t_619 = self.n_Conv_113(t_618) + t_620 = F.relu(t_619) + t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0) + t_621 = self.n_Conv_114(t_620_padded) + t_622 = F.relu(t_621) + t_623 = self.n_Conv_115(t_622) + t_624 = torch.add(t_623, t_618) + t_625 = F.relu(t_624) + t_626 = self.n_Conv_116(t_625) + t_627 = F.relu(t_626) + t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0) + t_628 = self.n_Conv_117(t_627_padded) + t_629 = F.relu(t_628) + t_630 = self.n_Conv_118(t_629) + t_631 = torch.add(t_630, t_625) + t_632 = F.relu(t_631) + t_633 = self.n_Conv_119(t_632) + t_634 = F.relu(t_633) + t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0) + t_635 = self.n_Conv_120(t_634_padded) + t_636 = F.relu(t_635) + t_637 = self.n_Conv_121(t_636) + t_638 = torch.add(t_637, t_632) + t_639 = F.relu(t_638) + t_640 = self.n_Conv_122(t_639) + t_641 = F.relu(t_640) + t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0) + t_642 = self.n_Conv_123(t_641_padded) + t_643 = F.relu(t_642) + t_644 = self.n_Conv_124(t_643) + t_645 = torch.add(t_644, t_639) + t_646 = F.relu(t_645) + t_647 = self.n_Conv_125(t_646) + t_648 = F.relu(t_647) + t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0) + t_649 = self.n_Conv_126(t_648_padded) + t_650 = F.relu(t_649) + t_651 = self.n_Conv_127(t_650) + t_652 = torch.add(t_651, t_646) + t_653 = F.relu(t_652) + t_654 = self.n_Conv_128(t_653) + t_655 = F.relu(t_654) + t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0) + t_656 = self.n_Conv_129(t_655_padded) + t_657 = F.relu(t_656) + t_658 = self.n_Conv_130(t_657) + t_659 = torch.add(t_658, t_653) + t_660 = F.relu(t_659) + t_661 = self.n_Conv_131(t_660) + t_662 = F.relu(t_661) + t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0) + t_663 = self.n_Conv_132(t_662_padded) + t_664 = F.relu(t_663) + t_665 = self.n_Conv_133(t_664) + t_666 = torch.add(t_665, t_660) + t_667 = F.relu(t_666) + t_668 = self.n_Conv_134(t_667) + t_669 = F.relu(t_668) + t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0) + t_670 = self.n_Conv_135(t_669_padded) + t_671 = F.relu(t_670) + t_672 = self.n_Conv_136(t_671) + t_673 = torch.add(t_672, t_667) + t_674 = F.relu(t_673) + t_675 = self.n_Conv_137(t_674) + t_676 = F.relu(t_675) + t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0) + t_677 = self.n_Conv_138(t_676_padded) + t_678 = F.relu(t_677) + t_679 = self.n_Conv_139(t_678) + t_680 = torch.add(t_679, t_674) + t_681 = F.relu(t_680) + t_682 = self.n_Conv_140(t_681) + t_683 = F.relu(t_682) + t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0) + t_684 = self.n_Conv_141(t_683_padded) + t_685 = F.relu(t_684) + t_686 = self.n_Conv_142(t_685) + t_687 = torch.add(t_686, t_681) + t_688 = F.relu(t_687) + t_689 = self.n_Conv_143(t_688) + t_690 = F.relu(t_689) + t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0) + t_691 = self.n_Conv_144(t_690_padded) + t_692 = F.relu(t_691) + t_693 = self.n_Conv_145(t_692) + t_694 = torch.add(t_693, t_688) + t_695 = F.relu(t_694) + t_696 = self.n_Conv_146(t_695) + t_697 = F.relu(t_696) + t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0) + t_698 = self.n_Conv_147(t_697_padded) + t_699 = F.relu(t_698) + t_700 = self.n_Conv_148(t_699) + t_701 = torch.add(t_700, t_695) + t_702 = F.relu(t_701) + t_703 = self.n_Conv_149(t_702) + t_704 = F.relu(t_703) + t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0) + t_705 = self.n_Conv_150(t_704_padded) + t_706 = F.relu(t_705) + t_707 = self.n_Conv_151(t_706) + t_708 = torch.add(t_707, t_702) + t_709 = F.relu(t_708) + t_710 = self.n_Conv_152(t_709) + t_711 = F.relu(t_710) + t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0) + t_712 = self.n_Conv_153(t_711_padded) + t_713 = F.relu(t_712) + t_714 = self.n_Conv_154(t_713) + t_715 = torch.add(t_714, t_709) + t_716 = F.relu(t_715) + t_717 = self.n_Conv_155(t_716) + t_718 = F.relu(t_717) + t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0) + t_719 = self.n_Conv_156(t_718_padded) + t_720 = F.relu(t_719) + t_721 = self.n_Conv_157(t_720) + t_722 = torch.add(t_721, t_716) + t_723 = F.relu(t_722) + t_724 = self.n_Conv_158(t_723) + t_725 = self.n_Conv_159(t_723) + t_726 = F.relu(t_725) + t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0) + t_727 = self.n_Conv_160(t_726_padded) + t_728 = F.relu(t_727) + t_729 = self.n_Conv_161(t_728) + t_730 = torch.add(t_729, t_724) + t_731 = F.relu(t_730) + t_732 = self.n_Conv_162(t_731) + t_733 = F.relu(t_732) + t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0) + t_734 = self.n_Conv_163(t_733_padded) + t_735 = F.relu(t_734) + t_736 = self.n_Conv_164(t_735) + t_737 = torch.add(t_736, t_731) + t_738 = F.relu(t_737) + t_739 = self.n_Conv_165(t_738) + t_740 = F.relu(t_739) + t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0) + t_741 = self.n_Conv_166(t_740_padded) + t_742 = F.relu(t_741) + t_743 = self.n_Conv_167(t_742) + t_744 = torch.add(t_743, t_738) + t_745 = F.relu(t_744) + t_746 = self.n_Conv_168(t_745) + t_747 = self.n_Conv_169(t_745) + t_748 = F.relu(t_747) + t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0) + t_749 = self.n_Conv_170(t_748_padded) + t_750 = F.relu(t_749) + t_751 = self.n_Conv_171(t_750) + t_752 = torch.add(t_751, t_746) + t_753 = F.relu(t_752) + t_754 = self.n_Conv_172(t_753) + t_755 = F.relu(t_754) + t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0) + t_756 = self.n_Conv_173(t_755_padded) + t_757 = F.relu(t_756) + t_758 = self.n_Conv_174(t_757) + t_759 = torch.add(t_758, t_753) + t_760 = F.relu(t_759) + t_761 = self.n_Conv_175(t_760) + t_762 = F.relu(t_761) + t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0) + t_763 = self.n_Conv_176(t_762_padded) + t_764 = F.relu(t_763) + t_765 = self.n_Conv_177(t_764) + t_766 = torch.add(t_765, t_760) + t_767 = F.relu(t_766) + t_768 = self.n_Conv_178(t_767) + t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:]) + t_770 = torch.squeeze(t_769, 3) + t_770 = torch.squeeze(t_770, 2) + t_771 = torch.sigmoid(t_770) + return t_771 + + def load_state_dict(self, state_dict, **kwargs): + self.tags = state_dict.get('tags', []) + + super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) + diff --git a/modules/shared.py b/modules/shared.py index a4457305..c93ae2a3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,7 +55,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") -parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") +parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 488aa5b5..56b9b2eb 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -6,12 +6,10 @@ import sys import tqdm import time -from modules import shared, images +from modules import shared, images, deepbooru from modules.paths import models_path from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -if cmd_opts.deepdanbooru: - import modules.deepbooru as deepbooru def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): @@ -20,9 +18,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce shared.interrogator.load() if process_caption_deepbooru: - db_opts = deepbooru.create_deepbooru_opts() - db_opts[deepbooru.OPT_INCLUDE_RANKS] = False - deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) + deepbooru.model.start() preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug) @@ -32,7 +28,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce shared.interrogator.send_blip_to_ram() if process_caption_deepbooru: - deepbooru.release_process() + deepbooru.model.stop() def listfiles(dirname): @@ -58,7 +54,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti if params.process_caption_deepbooru: if len(caption) > 0: caption += ", " - caption += deepbooru.get_tags_from_process(image) + caption += deepbooru.model.tag_multi(image) filename_part = params.src filename_part = os.path.splitext(filename_part)[0] diff --git a/modules/ui.py b/modules/ui.py index a5953fce..e6da1b2a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,14 +19,11 @@ import numpy as np from PIL import Image, PngImagePlugin -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts -if cmd_opts.deepdanbooru: - from modules.deepbooru import get_deepbooru_tags - import modules.codeformer_model import modules.generation_parameters_copypaste as parameters_copypaste import modules.gfpgan_model @@ -352,7 +349,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) + prompt = deepbooru.model.tag(image) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From ce6911158b5b2f9cf79b405a1f368f875492044d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 26 Nov 2022 16:10:46 +0300 Subject: Add support Stable Diffusion 2.0 --- README.md | 21 +- launch.py | 12 +- modules/paths.py | 2 +- modules/sd_hijack.py | 297 +++--------------------- modules/sd_hijack_clip.py | 301 +++++++++++++++++++++++++ modules/sd_hijack_inpainting.py | 20 +- modules/sd_hijack_open_clip.py | 37 +++ modules/sd_samplers.py | 14 +- modules/shared.py | 34 ++- modules/textual_inversion/textual_inversion.py | 7 +- modules/ui.py | 13 +- requirements.txt | 1 + requirements_versions.txt | 1 + v1-inference.yaml | 70 ++++++ webui.py | 5 +- 15 files changed, 504 insertions(+), 331 deletions(-) create mode 100644 modules/sd_hijack_clip.py create mode 100644 modules/sd_hijack_open_clip.py create mode 100644 v1-inference.yaml (limited to 'modules/shared.py') diff --git a/README.md b/README.md index 5f5ab3aa..8a4ffade 100644 --- a/README.md +++ b/README.md @@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - -## Where are Aesthetic Gradients?!?! -Aesthetic Gradients are now an extension. You can install it using git: - -```commandline -git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients -``` - -After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart -the UI. The interface for Aesthetic Gradients should appear exactly the same as it was. - -## Where is History/Image browser?!?! -Image browser is now an extension. You can install it using git: - -```commandline -git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser -``` - -After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart -the UI. The interface for Image browser should appear exactly the same as it was. +- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/launch.py b/launch.py index d2f1055c..b1626cb5 100644 --- a/launch.py +++ b/launch.py @@ -134,18 +134,19 @@ def prepare_enviroment(): gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") + openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") + stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -179,6 +180,9 @@ def prepare_enviroment(): if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") + if not is_installed("open_clip"): + run_pip(f"install {openclip_package}", "open_clip") + if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): @@ -196,7 +200,7 @@ def prepare_enviroment(): os.makedirs(dir_repos, exist_ok=True) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 1e7a2fbc..4dd03a35 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -9,7 +9,7 @@ sys.path.insert(0, script_path) # search for directory of stable diffusion in following places sd_path = None -possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] +possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] for possible_sd_path in possible_sd_paths: if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): sd_path = os.path.abspath(possible_sd_path) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..d5243fd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,18 +9,29 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import cmd_opts +from modules import sd_hijack_clip, sd_hijack_open_clip + from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.models.diffusion.ddim import ldm.models.diffusion.plms +import ldm.modules.encoders.modules attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# new memory efficient cross attention blocks do not support hypernets and we already +# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention + +# silence new console spam from SD2 +ldm.modules.attention.print = lambda *args: None +ldm.modules.diffusionmodules.model.print = lambda *args: None def apply_optimizations(): undo_optimizations() @@ -49,16 +60,11 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetworks import hypernetwork - - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 - class StableDiffusionModelHijack: fixes = None @@ -70,10 +76,13 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: + m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model @@ -89,12 +98,15 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords: + if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: - model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: + model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords: + m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped + m.cond_stage_model = m.cond_stage_model.wrapped self.apply_circular(False) self.layers = None @@ -114,261 +126,8 @@ class StableDiffusionModelHijack: def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) - - -class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): - def __init__(self, wrapped, hijack): - super().__init__() - self.wrapped = wrapped - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - self.token_mults = {} - - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] - for text, ident in tokens_with_parens: - mult = 1.0 - for c in text: - if c == '[': - mult /= 1.1 - if c == ']': - mult *= 1.1 - if c == '(': - mult *= 1.1 - if c == ')': - mult /= 1.1 - - if mult != 1.0: - self.token_mults[ident] = mult - - def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_end = self.wrapped.tokenizer.eos_token_id - - if opts.enable_emphasis: - parsed = prompt_parser.parse_prompt_attention(line) - else: - parsed = [[line, 1.0]] - - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] - - fixes = [] - remade_tokens = [] - multipliers = [] - last_comma = -1 - - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] - - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) - - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - - if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_multipliers = [] - for line in texts: - if line in cache: - remade_tokens, fixes, multipliers = cache[line] - else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) - token_count = max(current_token_count, token_count) - - cache[line] = (remade_tokens, fixes, multipliers) - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def process_text_old(self, text): - id_start = self.wrapped.tokenizer.bos_token_id - id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - overflowing_words = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) - - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 - - return z - - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - - tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - - if opts.CLIP_stop_at_last_layers > 1: - z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] - z = self.wrapped.transformer.text_model.final_layer_norm(z) - else: - z = outputs.last_hidden_state - - # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) - original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) - new_mean = z.mean() - z *= original_mean / new_mean + return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) - return z class EmbeddingsWithFixes(torch.nn.Module): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py new file mode 100644 index 00000000..b451d1cf --- /dev/null +++ b/modules/sd_hijack_clip.py @@ -0,0 +1,301 @@ +import math + +import torch + +from modules import prompt_parser, devices +from modules.shared import opts + + +def get_target_prompt_token_count(token_count): + return math.ceil(max(token_count, 1) / 75) * 75 + + +class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + def __init__(self, wrapped, hijack): + super().__init__() + self.wrapped = wrapped + self.hijack = hijack + + def tokenize(self, texts): + raise NotImplementedError + + def encode_with_transformers(self, tokens): + raise NotImplementedError + + def encode_embedding_init_text(self, init_text, nvpt): + raise NotImplementedError + + def tokenize_line(self, line, used_custom_terms, hijack_comments): + if opts.enable_emphasis: + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + fixes = [] + remade_tokens = [] + multipliers = [] + last_comma = -1 + + for tokens, (text, weight) in zip(tokenized, parsed): + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [self.id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + + if embedding is None: + remade_tokens.append(token) + multipliers.append(weight) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [self.id_end] * rem + multipliers += [1.0] * rem + iteration += 1 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + token_count = len(remade_tokens) + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + + remade_tokens = remade_tokens + [self.id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add + + return remade_tokens, fixes, multipliers, token_count + + def process_text(self, texts): + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_multipliers = [] + for line in texts: + if line in cache: + remade_tokens, fixes, multipliers = cache[line] + else: + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) + + cache[line] = (remade_tokens, fixes, multipliers) + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def process_text_old(self, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + use_old = opts.use_old_emphasis_implementation + if use_old: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + else: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + tokens = [] + multipliers = [] + for j in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[j]) > 0: + tokens.append(remade_batch_tokens[j][:75]) + multipliers.append(batch_multipliers[j][:75]) + else: + tokens.append([self.id_end] * 75) + multipliers.append([1.0] * 75) + + z1 = self.process_tokens(tokens, multipliers) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(devices.device) + + if self.id_end != self.id_pad: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(self.id_end) + tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad + + z = self.encode_with_transformers(tokens) + + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + original_mean = z.mean() + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + new_mean = z.mean() + z *= original_mean / new_mean + + return z + + +class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + self.tokenizer = wrapped.tokenizer + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.transformer.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 46714a4f..938f9a58 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -199,8 +199,8 @@ def sample_plms(self, @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): b, *_, device = *x.shape, x.device def get_model_output(x, t): @@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature @@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): - ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + # most of this stuff seems to no longer be needed because it is already included into SD2.0 + # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint + # p_sample_plms is needed because PLMS can't work with dicts as conditionings + # this file should be cleaned up later if weverything tuens out to work fine + + # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion - ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms - ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms - + # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py new file mode 100644 index 00000000..f733e852 --- /dev/null +++ b/modules/sd_hijack_open_clip.py @@ -0,0 +1,37 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + +tokenizer = open_clip.tokenizer._tokenizer + + +class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers + z = self.wrapped.encode_with_transformer(tokens) + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4fe67854..4edd8c60 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -127,7 +127,8 @@ class InterruptedException(BaseException): class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) - self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms + self.is_plms = hasattr(self.sampler, 'p_sample_plms') + self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim self.mask = None self.nmask = None self.init_latent = None @@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def adjust_steps_if_invalid(self, p, num_steps): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) @@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler: return num_steps - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) steps = self.adjust_steps_if_invalid(p, steps) @@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler: steps = self.adjust_steps_if_invalid(p, steps or p.steps) # Wrap the conditioning models with additional image conditioning for inpainting model + # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -350,7 +350,9 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) diff --git a/modules/shared.py b/modules/shared.py index c93ae2a3..8fb1387a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,17 +11,15 @@ import tqdm import modules.artists import modules.interrogate import modules.memmon -import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading -from modules.hypernetworks import hypernetwork +from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -121,10 +119,12 @@ xformers_available = False config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) -hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) +hypernetworks = {} loaded_hypernetwork = None + def reload_hypernetworks(): + from modules.hypernetworks import hypernetwork global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) @@ -206,10 +206,11 @@ class State: if self.current_latent is None: return + import modules.sd_samplers if opts.show_progress_grid: - self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) else: - self.current_image = sd_samplers.sample_to_image(self.current_latent) + self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) self.current_image_sampling_step = self.sampling_step @@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict): return options_dict +def list_checkpoint_tiles(): + import modules.sd_models + return modules.sd_models.checkpoint_tiles() + + +def refresh_checkpoints(): + import modules.sd_models + return modules.sd_models.list_models() + + +def list_samplers(): + import modules.sd_samplers + return modules.sd_samplers.all_samplers + + hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} options_templates = {} @@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), @@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5e4d8688..a273e663 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -64,7 +64,8 @@ class EmbeddingDatabase: self.word_embeddings[embedding.name] = embedding - ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working + ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] if first_id not in self.ids_lookup: @@ -155,13 +156,11 @@ class EmbeddingDatabase: def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model - embedding_layer = cond_model.wrapped.transformer.text_model.embeddings with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active - ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): diff --git a/modules/ui.py b/modules/ui.py index e6da1b2a..e5cb69d0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -478,9 +478,7 @@ def create_toprow(is_img2img): if is_img2img: with gr.Column(scale=1, elem_id="interrogate_col"): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - - if cmd_opts.deepdanbooru: - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): with gr.Row(): @@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - if cmd_opts.deepdanbooru: - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], ) diff --git a/requirements.txt b/requirements.txt index 762db4f3..e4e5ec64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ kornia lark inflection GitPython +torchsde diff --git a/requirements_versions.txt b/requirements_versions.txt index 662ca684..8d557fe3 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -25,3 +25,4 @@ kornia==0.6.7 lark==1.1.2 inflection==0.5.1 GitPython==3.1.27 +torchsde==0.2.5 diff --git a/v1-inference.yaml b/v1-inference.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/webui.py b/webui.py index c5e5fe75..23215d1e 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler, extensions, localization +from modules import shared, devices, sd_samplers, upscaler, extensions, localization import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -23,7 +23,6 @@ import modules.scripts import modules.sd_hijack import modules.sd_models import modules.sd_vae -import modules.shared as shared import modules.txt2img import modules.script_callbacks @@ -86,7 +85,7 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: -- cgit v1.2.3 From b006382784a2f0887317bb60ea49d19b50a5dc7e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 11:52:53 +0300 Subject: serve images from where they are saved instead of a temporary directory add an option to choose a different temporary directory in the UI add an option to cleanup the selected temporary directory at startup --- modules/images.py | 2 ++ modules/shared.py | 7 ++++++ modules/ui.py | 16 ------------- modules/ui_tempdir.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++ webui.py | 16 ++++++++----- 5 files changed, 82 insertions(+), 21 deletions(-) create mode 100644 modules/ui_tempdir.py (limited to 'modules/shared.py') diff --git a/modules/images.py b/modules/images.py index 26d5b7a9..8737ccff 100644 --- a/modules/images.py +++ b/modules/images.py @@ -524,6 +524,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: image.save(fullfn, quality=opts.jpeg_quality) + image.already_saved_as = fullfn + target_side_length = 4000 oversize = image.width > target_side_length or image.height > target_side_length if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024): diff --git a/modules/shared.py b/modules/shared.py index 8fb1387a..af975f54 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,6 +16,9 @@ import modules.devices as devices from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path + +demo = None + sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() @@ -292,6 +295,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { diff --git a/modules/ui.py b/modules/ui.py index c8b8fecd..ea925c40 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -157,22 +157,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def save_pil_to_file(pil_image, dir=None): - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in pil_image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True - - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) - return file_obj - - -# override save to file function so that it also writes PNG info -gr.processing_utils.save_pil_to_file = save_pil_to_file - def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py new file mode 100644 index 00000000..9c6d3a9d --- /dev/null +++ b/modules/ui_tempdir.py @@ -0,0 +1,62 @@ +import os +import tempfile +from collections import namedtuple + +import gradio as gr + +from PIL import PngImagePlugin + +from modules import shared + + +Savedfile = namedtuple("Savedfile", ["name"]) + + +def save_pil_to_file(pil_image, dir=None): + already_saved_as = getattr(pil_image, 'already_saved_as', None) + if already_saved_as: + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))} + file_obj = Savedfile(already_saved_as) + return file_obj + + if shared.opts.temp_dir != "": + dir = shared.opts.temp_dir + + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def on_tmpdir_changed(): + if shared.opts.temp_dir == "" or shared.demo is None: + return + + os.makedirs(shared.opts.temp_dir, exist_ok=True) + + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)} + + +def cleanup_tmpdr(): + temp_dir = shared.opts.temp_dir + if temp_dir == "" or not os.path.isdir(temp_dir): + return + + for root, dirs, files in os.walk(temp_dir, topdown=False): + for name in files: + _, extension = os.path.splitext(name) + if extension != ".png": + continue + + filename = os.path.join(root, name) + os.remove(filename) diff --git a/webui.py b/webui.py index 23215d1e..6b79dc55 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import shared, devices, sd_samplers, upscaler, extensions, localization +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -31,12 +31,14 @@ from modules import modelloader from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork + queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None + def wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: @@ -87,6 +89,7 @@ def initialize(): shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -149,9 +152,12 @@ def webui(): initialize() while 1: - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - app, local_url, share_url = demo.launch( + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, @@ -178,9 +184,9 @@ def webui(): if launch_api: create_api(app) - modules.script_callbacks.app_started_callback(demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) - wait_on_server(demo) + wait_on_server(shared.demo) sd_samplers.set_samplers() -- cgit v1.2.3 From be2e6de94a5d40bff6d65497fd5ebc275b389f3f Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Dec 2022 11:34:16 -0800 Subject: Fix clip skip of 1 not being restored from prompts --- modules/generation_parameters_copypaste.py | 4 ++++ modules/shared.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 01980dca..44fe1a6c 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -184,6 +184,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model else: res[k] = v + # Missing CLIP skip means it was set to 1 (the default) + if "Clip skip" not in res: + res["Clip skip"] = "1" + return res diff --git a/modules/shared.py b/modules/shared.py index c36ee211..b4ecc7ca 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -371,7 +371,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop at last layers of CLIP model (CLIP skip)", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From c7af672186ec09a514f0e78aa21155264e56c130 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Dec 2022 09:41:39 +0300 Subject: more simple config option name plus mouseover hint for clip skip --- javascript/hints.js | 2 ++ modules/shared.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index ac417ff6..57db35be 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -94,6 +94,8 @@ titles = { "Add difference": "Result = A + (B - C) * M", "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", + + "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc." } diff --git a/modules/shared.py b/modules/shared.py index b4ecc7ca..42ec4120 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -371,7 +371,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop at last layers of CLIP model (CLIP skip)", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3