From d07cb46f34b3d9fe7a78b102f899ebef352ea56b Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 20 Oct 2022 23:58:52 +0800 Subject: inspiration pull request --- webui.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 177bef74..5923905f 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + if cmd_opts.ui_debug_mode: + class enmpty(): + name = None + shared.sd_upscalers = [enmpty()] + return modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) -- cgit v1.2.3 From 5f4fec307c14dd7f817244ffa92e8a4a64abed0b Mon Sep 17 00:00:00 2001 From: Stephen Date: Thu, 20 Oct 2022 11:32:17 -0400 Subject: [Bugfix][API] - Fix API arg in launch script --- webui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 177bef74..87589064 100644 --- a/webui.py +++ b/webui.py @@ -118,7 +118,8 @@ def api_only(): api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -def webui(launch_api=False): +def webui(): + launch_api = cmd_opts.api initialize() while 1: @@ -158,4 +159,4 @@ if __name__ == "__main__": if cmd_opts.nowebui: api_only() else: - webui(cmd_opts.api) + webui() -- cgit v1.2.3 From bb0f1a2cdae3410a41d06ae878f56e29b8154c41 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 01:23:00 +0800 Subject: inspiration finished --- javascript/inspiration.js | 27 ++++--- modules/inspiration.py | 192 ++++++++++++++++++++++++++++++---------------- modules/shared.py | 6 ++ modules/ui.py | 2 +- webui.py | 3 +- 5 files changed, 151 insertions(+), 79 deletions(-) (limited to 'webui.py') diff --git a/javascript/inspiration.js b/javascript/inspiration.js index e1c0e114..791a80c9 100644 --- a/javascript/inspiration.js +++ b/javascript/inspiration.js @@ -1,25 +1,31 @@ function public_image_index_in_gallery(item, gallery){ + var imgs = gallery.querySelectorAll("img.h-full") var index; var i = 0; - gallery.querySelectorAll("img").forEach(function(e){ + imgs.forEach(function(e){ if (e == item) index = i; i += 1; }); + var num = imgs.length / 2 + index = (index < num) ? index : (index - num) return index; } -function inspiration_selected(name, types, name_list){ +function inspiration_selected(name, name_list){ var btn = gradioApp().getElementById("inspiration_select_button") - return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index"), types]; -} + return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; +} +function inspiration_click_get_button(){ + gradioApp().getElementById("inspiration_get_button").click(); +} var inspiration_image_click = function(){ var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); - var btn = gradioApp().getElementById("inspiration_select_button") - btn.setAttribute("img-index", index) - setTimeout(function(btn){btn.click();}, 10, btn) + var btn = gradioApp().getElementById("inspiration_select_button"); + btn.setAttribute("img-index", index); + setTimeout(function(btn){btn.click();}, 10, btn); } - + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ var gallery = gradioApp().getElementById("inspiration_gallery") @@ -27,11 +33,10 @@ document.addEventListener("DOMContentLoaded", function() { var node = gallery.querySelector(".absolute.backdrop-blur.h-full") if (node) { node.style.display = "None"; //parentNode.removeChild(node) - } - + } gallery.querySelectorAll('img').forEach(function(e){ e.onclick = inspiration_image_click - }) + }); } diff --git a/modules/inspiration.py b/modules/inspiration.py index 456bfcb5..f72ebf3a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -1,122 +1,182 @@ import os import random -import gradio -inspiration_path = "inspiration" -inspiration_system_path = os.path.join(inspiration_path, "system") -def read_name_list(file): +import gradio +from modules.shared import opts +inspiration_system_path = os.path.join(opts.inspiration_dir, "system") +def read_name_list(file, types=None, keyword=None): if not os.path.exists(file): return [] - f = open(file, "r") ret = [] + f = open(file, "r") line = f.readline() while len(line) > 0: line = line.rstrip("\n") - ret.append(line) - print(ret) + if types is not None: + dirname = os.path.split(line) + if dirname[0] in types and keyword in dirname[1]: + ret.append(line) + else: + ret.append(line) + line = f.readline() return ret def save_name_list(file, name): - print(file) - f = open(file, "a") - f.write(name + "\n") + with open(file, "a") as f: + f.write(name + "\n") -def get_inspiration_images(source, types): - path = os.path.join(inspiration_path , types) +def get_types_list(): + files = os.listdir(opts.inspiration_dir) + types = [] + for x in files: + path = os.path.join(opts.inspiration_dir, x) + if x[0] == ".": + continue + if not os.path.isdir(path): + continue + if path == inspiration_system_path: + continue + types.append(x) + return types + +def get_inspiration_images(source, types, keyword): + get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) - names = random.sample(names, 25) + names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) + names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - names = random.sample(names, 25) - elif source == "Exclude abandoned": - abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - all_names = os.listdir(path) - names = [] - while len(names) < 25: - name = random.choice(all_names) - if name not in abondened: - names.append(name) + names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + print(names) + names = random.sample(names, get_num) if len(names) > get_num else names + elif source == "Exclude abandoned": + abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + + if len(all_names) > get_num: + names = [] + while len(names) < get_num: + name = random.choice(all_names) + if name not in abandoned: + names.append(name) + else: + names = all_names else: - names = random.sample(os.listdir(path), 25) - names = random.sample(names, 25) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: - image_path = os.path.join(path, a) + image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) - image_list.append(os.path.join(image_path, random.choice(images))) - return image_list, names + image_list.append((os.path.join(image_path, random.choice(images)), a)) + return image_list, names, "" -def select_click(index, types, name_list): +def select_click(index, name_list): name = name_list[int(index)] - path = os.path.join(inspiration_path, types, name) + path = os.path.join(opts.inspiration_dir, name) images = os.listdir(path) - return name, [os.path.join(path, x) for x in images] + return name, [os.path.join(path, x) for x in images], "" -def give_up_click(name, types): - file = os.path.join(inspiration_system_path, types + "_abandoned.txt") +def give_up_click(name): + file = os.path.join(inspiration_system_path, "abandoned.txt") name_list = read_name_list(file) if name not in name_list: save_name_list(file, name) + return "Added to abandoned list" -def collect_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") - print(file) +def collect_click(name): + file = os.path.join(inspiration_system_path, "faverites.txt") name_list = read_name_list(file) - print(name_list) if name not in name_list: save_name_list(file, name) + return "Added to faverite list" -def moveout_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") +def moveout_click(name, source): + if source == "Abandoned": + file = os.path.join(inspiration_system_path, "abandoned.txt") + if source == "Favorites": + file = os.path.join(inspiration_system_path, "faverites.txt") + else: + return None name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + os.remove(file) + with open(file, "a") as f: + for a in name_list: + if a != name: + f.write(a) + return "Moved out {name} from {source} list" def source_change(source): - if source == "Abandoned" or source == "Favorites": - return gradio.Button.update(visible=True, value=f"Move out {source}") + if source in ["Abandoned", "Favorites"]: + return gradio.update(visible=True), [] else: - return gradio.Button.update(visible=False) + return gradio.update(visible=False), [] +def add_to_prompt(name, prompt): + print(name, prompt) + name = os.path.basename(name) + return prompt + "," + name -def ui(gr, opts): +def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(inspiration_path) + flag = os.path.exists(opts.inspiration_dir) if flag: - types = os.listdir(inspiration_path) - types = [x for x in types if x != "system"] + types = get_types_list() flag = len(types) > 0 - if not flag: - os.mkdir(inspiration_path) + else: + os.makedirs(opts.inspiration_dir) + if not flag: gr.HTML(""" -
" +

To activate inspiration function, you need get "inspiration" images first.


+ You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ download these files, and select these files in the "Create inspiration images" script UI
+ There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style +
I suggest at least four images for each


+

You can also download generated pictures from here:


+ https://huggingface.co/datasets/yfszzx/inspiration
+ unzip the file to the project directory of webui
+ and restart webui, and enjoy the joy of creation!
""") return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) - gallery, names = get_inspiration_images("Exclude abandoned", types[0]) with gr.Row(): with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + print(types) + types = gr.CheckboxGroup(choices=types, value=types) + keyword = gr.Textbox("", label="Key word") with gr.Row(): source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') - + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') collect = gr.Button('Collect') - give_up = gr.Button("Don't show any more") + give_up = gr.Button("Don't show again") moveout = gr.Button("Move out", visible=False) - with gr.Row(): + warning = gr.HTML() + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State(names) - source.change(source_change, inputs=[source], outputs=[moveout]) - get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) - select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) - give_up.click(give_up_click, inputs=[name, types], outputs=None) - collect.click(collect_click, inputs=[name, types], outputs=None) + name_list = gr.State() + + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) + give_up.click(give_up_click, inputs=[name], outputs=[warning]) + collect.click(collect_click, inputs=[name], outputs=[warning]) + moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) + send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) + send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration diff --git a/modules/shared.py b/modules/shared.py index ae033710..564b1b8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -316,6 +316,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section(('inspiration', "Inspiration"), { + "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), + "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), + "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), +})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index 6a0a3c3b..b651eb9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1180,7 +1180,7 @@ def create_ui(wrap_gradio_gpu_call): } browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts) + inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): diff --git a/webui.py b/webui.py index 5923905f..5ccae715 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + modules.scripts.load_scripts(os.path.join(script_path, "scripts")) if cmd_opts.ui_debug_mode: class enmpty(): name = None @@ -84,7 +85,7 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -- cgit v1.2.3 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- .gitignore | 2 +- extensions/put extension here.txt | 0 modules/aesthetic_clip.py | 241 -------------------------------------- modules/images_history.py | 2 +- modules/img2img.py | 5 +- modules/processing.py | 35 ++++-- modules/script_callbacks.py | 42 +++++++ modules/scripts.py | 210 ++++++++++++++++++++++++--------- modules/sd_hijack.py | 1 - modules/sd_models.py | 7 +- modules/shared.py | 19 --- modules/txt2img.py | 5 +- modules/ui.py | 83 ++----------- webui.py | 7 +- 14 files changed, 249 insertions(+), 410 deletions(-) create mode 100644 extensions/put extension here.txt delete mode 100644 modules/aesthetic_clip.py create mode 100644 modules/script_callbacks.py (limited to 'webui.py') diff --git a/.gitignore b/.gitignore index f9c3357c..2f1e08ed 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode diff --git a/extensions/put extension here.txt b/extensions/put extension here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py deleted file mode 100644 index 8c828541..00000000 --- a/modules/aesthetic_clip.py +++ /dev/null @@ -1,241 +0,0 @@ -import copy -import itertools -import os -from pathlib import Path -import html -import gc - -import gradio as gr -import torch -from PIL import Image -from torch import optim - -from modules import shared -from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer -from tqdm.auto import tqdm, trange -from modules.shared import opts, device - - -def get_all_images_in_folder(folder): - return [os.path.join(folder, f) for f in os.listdir(folder) if - os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] - - -def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) - - -def batched(dataset, total, n=1): - for ndx in range(0, total, n): - yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] - - -def iter_to_batched(iterable, n=1): - it = iter(iterable) - while True: - chunk = tuple(itertools.islice(it, n)) - if not chunk: - return - yield chunk - - -def create_ui(): - import modules.ui - - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!", open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", - value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', - placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', - placeholder="This text is used to rotate the feature space of the imgs embs", - value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, - value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - - return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative - - -aesthetic_clip_model = None - - -def aesthetic_clip(): - global aesthetic_clip_model - - if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: - aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) - aesthetic_clip_model.cpu() - - return aesthetic_clip_model - - -def generate_imgs_embd(name, folder, batch_size): - model = aesthetic_clip().to(device) - processor = CLIPProcessor.from_pretrained(model.name_or_path) - - with torch.no_grad(): - embs = [] - for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), - desc=f"Generating embeddings for {name}"): - if shared.state.interrupted: - break - inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) - outputs = model.get_image_features(**inputs).cpu() - embs.append(torch.clone(outputs)) - inputs.to("cpu") - del inputs, outputs - - embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) - - # The generated embedding will be located here - path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") - torch.save(embs, path) - - model.cpu() - del processor - del embs - gc.collect() - torch.cuda.empty_cache() - res = f""" - Done generating embedding for {name}! - Aesthetic embedding saved to {html.escape(path)} - """ - shared.update_aesthetic_embeddings() - return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), \ - gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), - label="Imgs embedding", - value="None"), res, "" - - -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - -class AestheticCLIP: - def __init__(self): - self.skip = False - self.aesthetic_steps = 0 - self.aesthetic_weight = 0 - self.aesthetic_lr = 0 - self.slerp = False - self.aesthetic_text_negative = "" - self.aesthetic_slerp_angle = 0 - self.aesthetic_imgs_text = "" - - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) - - def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - if self.image_embs_name is not None: - p.extra_generation_params.update({ - "Aesthetic LR": aesthetic_lr, - "Aesthetic weight": aesthetic_weight, - "Aesthetic steps": aesthetic_steps, - "Aesthetic embedding": self.image_embs_name, - "Aesthetic slerp": aesthetic_slerp, - "Aesthetic text": aesthetic_imgs_text, - "Aesthetic text negative": aesthetic_text_negative, - "Aesthetic slerp angle": aesthetic_slerp_angle, - }) - - def set_skip(self, skip): - self.skip = skip - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - self.image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - - def __call__(self, z, remade_batch_tokens): - if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: - tokenizer = shared.sd_model.cond_stage_model.tokenizer - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(aesthetic_clip()).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - gc.collect() - torch.cuda.empty_cache() - zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - - return z diff --git a/modules/images_history.py b/modules/images_history.py index 78fd0543..bc5cf11f 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): forward = gr.Button('Prev batch') backward = gr.Button('Next batch') with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) diff --git a/modules/img2img.py b/modules/img2img.py index eea5199b..8d9f7cf9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index ff1ec4c9..372489f7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -104,6 +104,12 @@ class StableDiffusionProcessing(): self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + self.scripts = None + self.script_args = None + self.all_prompts = None + self.all_seeds = None + self.all_subseeds = None + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - all_prompts = p.prompt + p.all_prompts = p.prompt else: - all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_prompts = p.batch_size * p.n_iter * [p.prompt] if type(seed) == list: - all_seeds = seed + p.all_seeds = seed else: - all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] if type(subseed) == list: - all_subseeds = subseed + p.all_subseeds = subseed else: - all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + if p.scripts is not None: + p.scripts.run_alwayson_scripts(p) + infotexts = [] output_images = [] with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): - p.init(all_prompts, all_seeds, all_subseeds) + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) if state.job_count == -1: state.job_count = p.n_iter @@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break - prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] - seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] - subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] if (len(prompts) == 0): break @@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py new file mode 100644 index 00000000..866b7acd --- /dev/null +++ b/modules/script_callbacks.py @@ -0,0 +1,42 @@ + +callbacks_model_loaded = [] +callbacks_ui_tabs = [] + + +def clear_callbacks(): + callbacks_model_loaded.clear() + callbacks_ui_tabs.clear() + + +def model_loaded_callback(sd_model): + for callback in callbacks_model_loaded: + callback(sd_model) + + +def ui_tabs_callback(): + res = [] + + for callback in callbacks_ui_tabs: + res += callback() or [] + + return res + + +def on_model_loaded(callback): + """register a function to be called when the stable diffusion model is created; the model is + passed as an argument""" + callbacks_model_loaded.append(callback) + + +def on_ui_tabs(callback): + """register a function to be called when the UI is creating new tabs. + The function must either return a None, which means no new tabs to be added, or a list, where + each element is a tuple: + (gradio_component, title, elem_id) + + gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks) + title is tab text displayed to user in the UI + elem_id is HTML id for the tab + """ + callbacks_ui_tabs.append(callback) + diff --git a/modules/scripts.py b/modules/scripts.py index 1039fa9c..65f25f49 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,86 +1,153 @@ import os import sys import traceback +from collections import namedtuple import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared +from modules import shared, paths, script_callbacks + +AlwaysVisible = object() + class Script: filename = None args_from = None args_to = None + alwayson = False + + infotext_fields = None + """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when + parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example + """ - # The title of the script. This is what will be displayed in the dropdown menu. def title(self): + """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" + raise NotImplementedError() - # How the script is displayed in the UI. See https://gradio.app/docs/#components - # for the different UI components you can use and how to create them. - # Most UI components can return a value, such as a boolean for a checkbox. - # The returned values are passed to the run method as parameters. def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned componenbts will be passed to run() and process() functions. + """ + pass - # Determines when the script should be shown in the dropdown menu via the - # returned value. As an example: - # is_img2img is True if the current tab is img2img, and False if it is txt2img. - # Thus, return is_img2img to only show the script on the img2img tab. def show(self, is_img2img): + """ + is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + + This function should return: + - False if the script should not be shown in UI at all + - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - script.AlwaysVisible if the script should be shown in UI at all times + """ + return True - # This is where the additional processing is implemented. The parameters include - # self, the model object "p" (a StableDiffusionProcessing class, see - # processing.py), and the parameters returned by the ui method. - # Custom functions can be defined here, and additional libraries can be imported - # to be used in processing. The return value should be a Processed object, which is - # what is returned by the process_images method. - def run(self, *args): + def run(self, p, *args): + """ + This function is called if the script has been selected in the script dropdown. + It must do all processing and return the Processed object with results, same as + one returned by processing.process_images. + + Usually the processing is done by calling the processing.process_images function. + + args contains all values returned by components from ui() + """ + raise NotImplementedError() - # The description method is currently unused. - # To add a description that appears when hovering over the title, amend the "titles" - # dict in script.js to include the script title (returned by title) as a key, and - # your description as the value. + def process(self, p, *args): + """ + This function is called before processing begins for AlwaysVisible scripts. + scripts. You can modify the processing object (p) here, inject hooks, etc. + """ + + pass + def describe(self): + """unused""" return "" +current_basedir = paths.script_path + + +def basedir(): + """returns the base directory for the current script. For scripts in the main scripts directory, + this is the main directory (where webui.py resides), and for scripts in extensions directory + (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic) + """ + return current_basedir + + scripts_data = [] +ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) + + +def list_scripts(scriptdirname, extension): + scripts_list = [] + + basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(basedir): + for filename in sorted(os.listdir(basedir)): + scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + for dirname in sorted(os.listdir(extdir)): + dirpath = os.path.join(extdir, dirname) + if not os.path.isdir(dirpath): + continue + for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) -def load_scripts(basedir): - if not os.path.exists(basedir): - return + scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - for filename in sorted(os.listdir(basedir)): - path = os.path.join(basedir, filename) + return scripts_list - if os.path.splitext(path)[1].lower() != '.py': - continue - if not os.path.isfile(path): - continue +def load_scripts(): + global current_basedir + scripts_data.clear() + script_callbacks.clear_callbacks() + + scripts_list = list_scripts("scripts", ".py") + + syspath = sys.path + for scriptfile in sorted(scripts_list): try: - with open(path, "r", encoding="utf8") as file: + if scriptfile.basedir != paths.script_path: + sys.path = [scriptfile.basedir] + sys.path + current_basedir = scriptfile.basedir + + with open(scriptfile.path, "r", encoding="utf8") as file: text = file.read() from types import ModuleType - compiled = compile(text, path, 'exec') - module = ModuleType(filename) + compiled = compile(text, scriptfile.path, 'exec') + module = ModuleType(scriptfile.filename) exec(compiled, module.__dict__) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append((script_class, path)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) except Exception: - print(f"Error loading script: {filename}", file=sys.stderr) + print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + finally: + sys.path = syspath + current_basedir = paths.script_path + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: @@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.selectable_scripts = [] + self.alwayson_scripts = [] self.titles = [] + self.infotext_fields = [] def setup_ui(self, is_img2img): - for script_class, path in scripts_data: + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path - if not script.show(is_img2img): - continue + visibility = script.show(is_img2img) - self.scripts.append(script) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True - inputs = [dropdown] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + + inputs = [None] + inputs_alwayson = [True] - for script in self.scripts: + def create_script_ui(script, inputs, inputs_alwayson): script.args_from = len(inputs) script.args_to = len(inputs) controls = wrap_call(script.ui, script.filename, "ui", is_img2img) if controls is None: - continue + return for control in controls: control.custom_script_source = os.path.basename(script.filename) - control.visible = False + if not script.alwayson: + control.visible = False + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields inputs += controls + inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) + for script in self.alwayson_scripts: + with gr.Group(): + create_script_ui(script, inputs, inputs_alwayson) + + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True + inputs[0] = dropdown + + for script in self.selectable_scripts: + create_script_ui(script, inputs, inputs_alwayson) + def select_script(script_index): - if 0 < script_index <= len(self.scripts): - script = self.scripts[script_index-1] + if 0 < script_index <= len(self.selectable_scripts): + script = self.selectable_scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] def init_field(title): if title == 'None': return script_index = self.titles.index(title) - script = self.scripts[script_index] + script = self.selectable_scripts[script_index] for i in range(script.args_from, script.args_to): inputs[i].visible = True @@ -164,7 +255,7 @@ class ScriptRunner: if script_index == 0: return None - script = self.scripts[script_index-1] + script = self.selectable_scripts[script_index-1] if script is None: return None @@ -176,6 +267,15 @@ class ScriptRunner: return processed + def run_alwayson_scripts(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process(p, *script_args) + except Exception: + print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def reload_sources(self): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: @@ -197,19 +297,21 @@ class ScriptRunner: self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to + scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + def reload_script_body_only(): scripts_txt2img.reload_sources() scripts_img2img.reload_sources() -def reload_scripts(basedir): +def reload_scripts(): global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1f8587d1..0f10828e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) - z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens diff --git a/modules/sd_models.py b/modules/sd_models.py index d99dbce8..f9b3063d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model(checkpoint_info) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/shared.py b/modules/shared.py index 0dbe360d..7d786f07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") @@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - -os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True) -aesthetic_embeddings = {} - - -def update_aesthetic_embeddings(): - global aesthetic_embeddings - aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) - - -update_aesthetic_embeddings() - - def reload_hypernetworks(): global hypernetworks @@ -415,9 +399,6 @@ sd_model = None clip_model = None -from modules.aesthetic_clip import AestheticCLIP -aesthetic_clip = AestheticCLIP() - progress_print_out = sys.stdout diff --git a/modules/txt2img.py b/modules/txt2img.py index 1761cfa2..c9d5a090 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,7 +7,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 70a9cf10..c977482c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,10 +23,10 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models, localization +from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags @@ -44,7 +44,6 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength, firstphase_width, firstphase_height, - aesthetic_lr, - aesthetic_weight, - aesthetic_steps, - aesthetic_imgs, - aesthetic_slerp, - aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative ] + custom_inputs, outputs=[ @@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), - (aesthetic_lr, "Aesthetic LR"), - (aesthetic_weight, "Aesthetic weight"), - (aesthetic_steps, "Aesthetic steps"), - (aesthetic_imgs, "Aesthetic embedding"), - (aesthetic_slerp, "Aesthetic slerp"), - (aesthetic_imgs_text, "Aesthetic text"), - (aesthetic_text_negative, "Aesthetic text negative"), - (aesthetic_slerp_angle, "Aesthetic slerp angle"), + *modules.scripts.scripts_txt2img.infotext_fields ] txt2img_preview_params = [ @@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - aesthetic_lr_im, - aesthetic_weight_im, - aesthetic_steps_im, - aesthetic_imgs_im, - aesthetic_slerp_im, - aesthetic_imgs_text_im, - aesthetic_slerp_angle_im, - aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), - (aesthetic_lr_im, "Aesthetic LR"), - (aesthetic_weight_im, "Aesthetic weight"), - (aesthetic_steps_im, "Aesthetic steps"), - (aesthetic_imgs_im, "Aesthetic embedding"), - (aesthetic_slerp_im, "Aesthetic slerp"), - (aesthetic_imgs_text_im, "Aesthetic text"), - (aesthetic_text_negative_im, "Aesthetic text negative"), - (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), + *modules.scripts.scripts_img2img.infotext_fields ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call): ) #images history images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields + "fn": modules.generation_parameters_copypaste.connect_paste, + "t2i": txt2img_paste_fields, + "i2i": img2img_paste_fields } images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) @@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create aesthetic images embedding"): - - new_embedding_name_ae = gr.Textbox(label="Name") - process_src_ae = gr.Textbox(label='Source directory') - batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') - with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call): ] ) - create_embedding_ae.click( - fn=aesthetic_clip.generate_imgs_embd, - inputs=[ - new_embedding_name_ae, - process_src_ae, - batch_ae - ], - outputs=[ - aesthetic_imgs, - aesthetic_imgs_im, - ti_output, - ti_outcome, - ] - ) - create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ @@ -1580,10 +1518,10 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + oldval = opts.data.get(key, None) if cmd_opts.hide_ui_dir_config and key in restricted_opts: return gr.update(value=oldval), opts.dumpjson() - oldval = opts.data.get(key, None) opts.data[key] = value if oldval != value: @@ -1692,9 +1630,12 @@ Requested path was: {f} (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), - (settings_interface, "Settings", "settings"), ] + interfaces += script_callbacks.ui_tabs_callback() + + interfaces += [(settings_interface, "Settings", "settings")] + with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: css = file.read() diff --git a/webui.py b/webui.py index 87589064..b1deca1b 100644 --- a/webui.py +++ b/webui.py @@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -79,9 +80,9 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + modules.scripts.load_scripts() - shared.sd_model = modules.sd_models.load_model() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) @@ -145,7 +146,7 @@ def webui(): sd_samplers.set_samplers() print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) print('Refreshing Model List') -- cgit v1.2.3 From 696cb33e50faf3f37859ebfba70fff902f46b8fb Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 23 Oct 2022 16:46:54 +0900 Subject: after initial launch, disable --autolaunch for subsequent restarts --- webui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'webui.py') diff --git a/webui.py b/webui.py index b1deca1b..a742c17d 100644 --- a/webui.py +++ b/webui.py @@ -135,6 +135,8 @@ def webui(): inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) + # after initial launch, disable --autolaunch for subsequent restarts + cmd_opts.autolaunch = False app.add_middleware(GZipMiddleware, minimum_size=1000) -- cgit v1.2.3 From 876a96f0f9843382ebc8984db3de5d8af0e9ce4c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 09:39:46 +0300 Subject: remove erroneous dir in the extension directory remove loading .js files from scripts dir (they go into javascript) load scripts after models, for scripts that depend on loaded models --- extensions/stable-diffusion-webui-inspiration | 1 - modules/ui.py | 2 +- webui.py | 11 ++++++----- 3 files changed, 7 insertions(+), 7 deletions(-) delete mode 160000 extensions/stable-diffusion-webui-inspiration (limited to 'webui.py') diff --git a/extensions/stable-diffusion-webui-inspiration b/extensions/stable-diffusion-webui-inspiration deleted file mode 160000 index a0b96664..00000000 --- a/extensions/stable-diffusion-webui-inspiration +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a0b96664d2524b87916ae463fbb65411b13a569b diff --git a/modules/ui.py b/modules/ui.py index a73b9ff0..03528968 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1885,7 +1885,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") - scripts_list += modules.scripts.list_scripts("scripts", ".js") + for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" diff --git a/webui.py b/webui.py index a0f3757f..ade7334b 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers +from modules import devices, sd_samplers, upscaler import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -73,12 +73,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): - modules.scripts.load_scripts() if cmd_opts.ui_debug_mode: - class enmpty(): - name = None - shared.sd_upscalers = [enmpty()] + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() return + modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) @@ -86,6 +85,8 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() + modules.scripts.load_scripts() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) -- cgit v1.2.3 From 149784202cca8612b43629c601ee27cfda64e623 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 30 Oct 2022 09:10:22 +0300 Subject: rework #3722 to not introduce duplicate code --- modules/api/api.py | 43 +++++++++++++------------------------------ modules/shared.py | 22 +++++++++++++++++++--- webui.py | 19 +++---------------- 3 files changed, 35 insertions(+), 49 deletions(-) (limited to 'webui.py') diff --git a/modules/api/api.py b/modules/api/api.py index 5c5b210f..6c06d449 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo -# copy from wrap_gradio_gpu_call of webui.py -# because queue lock will be acquired in api handlers -# and time start needs to be set -# the function has been modified into two parts - -def before_gpu_call(): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None - shared.state.time_start = time.time() - -def after_gpu_call(): - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() def upscaler_to_index(name: str): try: @@ -41,8 +16,10 @@ def upscaler_to_index(name: str): except: raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + def setUpscalers(req: dict): reqDict = vars(req) reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) @@ -51,6 +28,7 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -78,10 +56,13 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -119,11 +100,13 @@ class Api: imgs = [img] * p.batch_size p.init_images = imgs - # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/shared.py b/modules/shared.py index f7b0990c..e4f163c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -144,9 +144,6 @@ class State: self.sampling_step = 0 self.current_image_sampling_step = 0 - def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? - def dict(self): obj = { "skipped": self.skipped, @@ -160,6 +157,25 @@ class State: return obj + def begin(self): + self.sampling_step = 0 + self.job_count = -1 + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + + devices.torch_gc() + + def end(self): + self.job = "" + self.job_count = 0 + + devices.torch_gc() state = State() diff --git a/webui.py b/webui.py index ade7334b..29530872 100644 --- a/webui.py +++ b/webui.py @@ -46,26 +46,13 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None + + shared.state.begin() with queue_lock: res = func(*args, **kwargs) - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() + shared.state.end() return res -- cgit v1.2.3 From 423f22228306ae72d0480e25add9777c3c5d8fdf Mon Sep 17 00:00:00 2001 From: Maiko Tan Date: Sun, 30 Oct 2022 22:46:43 +0800 Subject: feat: add app started callback --- modules/script_callbacks.py | 15 +++++++++++++++ webui.py | 3 +++ 2 files changed, 18 insertions(+) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6ea58d61..f5509629 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -3,6 +3,8 @@ import traceback from collections import namedtuple import inspect +from fastapi import FastAPI +from gradio import Blocks def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) @@ -25,6 +27,7 @@ class ImageSaveParams: ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) +callbacks_app_started = [] callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -40,6 +43,14 @@ def clear_callbacks(): callbacks_image_saved.clear() +def app_started_callback(demo: Blocks, app: FastAPI): + for c in callbacks_app_started: + try: + c.callback(demo, app) + except Exception: + report_exception(c, 'app_started_callback') + + def model_loaded_callback(sd_model): for c in callbacks_model_loaded: try: @@ -91,6 +102,10 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) +def on_app_started(callback): + add_callback(callbacks_app_started, callback) + + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" diff --git a/webui.py b/webui.py index 29530872..13407e42 100644 --- a/webui.py +++ b/webui.py @@ -23,6 +23,7 @@ import modules.sd_hijack import modules.sd_models import modules.shared as shared import modules.txt2img +import modules.script_callbacks import modules.ui from modules import devices @@ -135,6 +136,8 @@ def webui(): if (launch_api): create_api(app) + modules.script_callbacks.app_started_callback(demo, app) + wait_on_server(demo) sd_samplers.set_samplers() -- cgit v1.2.3 From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_models.py | 31 +++++-------- modules/sd_vae.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 8 ++-- webui.py | 5 +++ 4 files changed, 141 insertions(+), 24 deletions(-) create mode 100644 modules/sd_vae.py (limited to 'webui.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..91ad4b5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks +from modules import shared, modelloader, devices, script_callbacks, sd_vae from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - -def load_model_weights(model, checkpoint_info): +def load_model_weights(model, checkpoint_info, force=False): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if checkpoint_info not in checkpoints_loaded: + if force or checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path - - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict) - + sd_vae.load_vae(model, checkpoint_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: @@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None): +def load_model(checkpoint_info=None, force=False): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -234,7 +223,7 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -252,16 +241,16 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model, info=None, force=False): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info) + load_model(checkpoint_info, force=force) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py new file mode 100644 index 00000000..82764e55 --- /dev/null +++ b/modules/sd_vae.py @@ -0,0 +1,121 @@ +import torch +import os +from collections import namedtuple +from modules import shared, devices +from modules.paths import models_path +import glob + +model_dir = "Stable-diffusion" +model_path = os.path.abspath(os.path.join(models_path, model_dir)) +vae_dir = "VAE" +vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_list = ["auto", "None"] +default_vae_values = [default_vae_dict[x] for x in default_vae_list] +vae_dict = dict(default_vae_dict) +vae_list = list(default_vae_list) +first_load = True + +def get_filename(filepath): + return os.path.splitext(os.path.basename(filepath))[0] + +def refresh_vae_list(vae_path=vae_path, model_path=model_path): + global vae_dict, vae_list + res = {} + candidates = [ + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + ] + if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): + candidates.append(shared.cmd_opts.vae_path) + for filepath in candidates: + name = get_filename(filepath) + res[name] = filepath + vae_list.clear() + vae_list.extend(default_vae_list) + vae_list.extend(list(res.keys())) + vae_dict.clear() + vae_dict.update(default_vae_dict) + vae_dict.update(res) + return vae_list + +def load_vae(model, checkpoint_file, vae_file="auto"): + global first_load, vae_dict, vae_list + # save_settings = False + + # if vae_file argument is provided, it takes priority + if vae_file and vae_file not in default_vae_list: + if not os.path.isfile(vae_file): + vae_file = "auto" + # save_settings = True + print("VAE provided as function argument doesn't exist") + # for the first load, if vae-path is provided, it takes priority and failure is reported + if first_load and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + # save_settings = True + # print("Using VAE provided as command line argument") + else: + print("VAE provided as command line argument doesn't exist") + # else, we load from settings + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + # vae-path cmd arg takes priority for auto + if vae_file == "auto" and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + print("Using VAE provided as command line argument") + # if still not found, try look for ".vae.pt" beside model + model_path = os.path.splitext(checkpoint_file)[0] + if vae_file == "auto": + vae_file_try = model_path + ".vae.pt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # if still not found, try look for ".vae.ckpt" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.ckpt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # No more fallbacks for auto + if vae_file == "auto": + vae_file = None + # Last check, just because + if vae_file and not os.path.exists(vae_file): + vae_file = None + + if vae_file: + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + model.first_stage_model.load_state_dict(vae_dict_1) + + # If vae used is not in dict, update it + # It will be removed on refresh though + if vae_file is not None: + vae_opt = get_filename(vae_file) + if vae_opt not in vae_dict: + vae_dict[vae_opt] = vae_file + vae_list.append(vae_opt) + + """ + # Save current VAE to VAE settings, maybe? will it work? + if save_settings: + if vae_file is None: + vae_opt = "None" + + # shared.opts.sd_vae = vae_opt + """ + + first_load = False + model.first_stage_model.to(devices.dtype_vae) diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..06440ac4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_samplers, sd_models, localization, sd_vae from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -295,6 +295,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), @@ -407,11 +408,12 @@ class Options: if bad_settings > 0: print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) - def onchange(self, key, func): + def onchange(self, key, func, call=True): item = self.data_labels.get(key) item.onchange = func - func() + if call: + func() def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} diff --git a/webui.py b/webui.py index 29530872..27949f3d 100644 --- a/webui.py +++ b/webui.py @@ -21,6 +21,7 @@ import modules.paths import modules.scripts import modules.sd_hijack import modules.sd_models +import modules.sd_vae import modules.shared as shared import modules.txt2img @@ -74,8 +75,12 @@ def initialize(): modules.scripts.load_scripts() + modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, + # so for now this reloads the whole model too, and no cache + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model, force=True)), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 17:36:45 +0300 Subject: add initial version of the extensions tab fix broken Restart Gradio button --- javascript/extensions.js | 24 +++++ modules/extensions.py | 83 +++++++++++++++ modules/generation_parameters_copypaste.py | 5 + modules/scripts.py | 21 +--- modules/shared.py | 10 +- modules/ui.py | 16 ++- modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++ style.css | 22 +++- webui.py | 20 ++-- 9 files changed, 333 insertions(+), 30 deletions(-) create mode 100644 javascript/extensions.js create mode 100644 modules/extensions.py create mode 100644 modules/ui_extensions.py (limited to 'webui.py') diff --git a/javascript/extensions.js b/javascript/extensions.js new file mode 100644 index 00000000..86f5336d --- /dev/null +++ b/javascript/extensions.js @@ -0,0 +1,24 @@ + +function extensions_apply(_, _){ + disable = [] + update = [] + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + + if(x.name.startsWith("update_") && x.checked) + update.push(x.name.substr(7)) + }) + + restart_reload() + + return [JSON.stringify(disable), JSON.stringify(update)] +} + +function extensions_check(){ + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ + x.innerHTML = "Loading..." + }) + + return [] +} \ No newline at end of file diff --git a/modules/extensions.py b/modules/extensions.py new file mode 100644 index 00000000..8d6ae848 --- /dev/null +++ b/modules/extensions.py @@ -0,0 +1,83 @@ +import os +import sys +import traceback + +import git + +from modules import paths, shared + + +extensions = [] +extensions_dir = os.path.join(paths.script_path, "extensions") + + +def active(): + return [x for x in extensions if x.enabled] + + +class Extension: + def __init__(self, name, path, enabled=True): + self.name = name + self.path = path + self.enabled = enabled + self.status = '' + self.can_update = False + + repo = None + try: + if os.path.exists(os.path.join(path, ".git")): + repo = git.Repo(path) + except Exception: + print(f"Error reading github repository info from {path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + if repo is None or repo.bare: + self.remote = None + else: + self.remote = next(repo.remote().urls, None) + self.status = 'unknown' + + def list_files(self, subdir, extension): + from modules import scripts + + dirpath = os.path.join(self.path, subdir) + if not os.path.isdir(dirpath): + return [] + + res = [] + for filename in sorted(os.listdir(dirpath)): + res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) + + res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + + return res + + def check_updates(self): + repo = git.Repo(self.path) + for fetch in repo.remote().fetch("--dry-run"): + if fetch.flags != fetch.HEAD_UPTODATE: + self.can_update = True + self.status = "behind" + return + + self.can_update = False + self.status = "latest" + + def pull(self): + repo = git.Repo(self.path) + repo.remotes.origin.pull() + + +def list_extensions(): + extensions.clear() + + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + + extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) + extensions.append(extension) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index df70c728..985ec95e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -17,6 +17,11 @@ paste_fields = {} bind_list = [] +def reset(): + paste_fields.clear() + bind_list.clear() + + def quote(text): if ',' not in str(text): return text diff --git a/modules/scripts.py b/modules/scripts.py index 96e44bfd..533db45c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks +from modules import shared, paths, script_callbacks, extensions AlwaysVisible = object() @@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - for dirname in sorted(os.listdir(extdir)): - dirpath = os.path.join(extdir, dirname) - scriptdirpath = os.path.join(dirpath, scriptdirname) - - if not os.path.isdir(scriptdirpath): - continue - - for filename in sorted(os.listdir(scriptdirpath)): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) + for ext in extensions.active(): + scripts_list += ext.list_files(scriptdirname, extension) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension): def list_files_with_name(filename): res = [] - dirs = [paths.script_path] - - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + dirs = [paths.script_path] + [ext.path for ext in extensions.active()] for dirpath in dirs: if not os.path.isdir(dirpath): diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..cce87081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -132,6 +132,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + need_restart = False def skip(self): self.skipped = True @@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), +})) + +options_templates.update() + class Options: data = None @@ -365,8 +372,9 @@ class Options: def __setattr__(self, key, value): if self.data is not None: - if key in self.data: + if key in self.data or key in self.data_labels: self.data[key] = value + return return super(Options, self).__setattr__(key, value) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..2c15abb7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin -from modules import sd_hijack, sd_models, localization, script_callbacks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call): column = None with gr.Row(elem_id="settings").style(equal_height=False): for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None - if previous_section != item.section: + if previous_section != item.section and not section_must_be_skipped: if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if column is not None: column.__exit__() @@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call): if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) else: component = create_setting_component(k) component_dict[k] = component @@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call): def request_restart(): shared.state.interrupt() - settings_interface.gradio_ref.do_restart = True + shared.state.need_restart = True restart_gradio.click( + fn=request_restart, inputs=[], outputs=[], @@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - settings_interface.gradio_ref = demo - parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.run_bind() diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 00000000..b7d747dc --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,162 @@ +import json +import os.path +import shutil +import sys +import time +import traceback + +import git + +import gradio as gr +import html + +from modules import extensions, shared, paths + + +def apply_and_restart(disable_list, update_list): + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.pull() + except Exception: + print(f"Error pulling updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + shared.state.interrupt() + shared.state.need_restart = True + + +def check_updates(): + for ext in extensions.extensions: + if ext.remote is None: + continue + + try: + ext.check_updates() + except Exception: + print(f"Error checking updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return extension_table() + + +def extension_table(): + code = f""" + + + + + + + + + + """ + + for ext in extensions.extensions: + if ext.can_update: + ext_status = f"""""" + else: + ext_status = ext.status + + code += f""" + + + + {ext_status} + + """ + + code += """ + +
ExtensionURLUpdate
{html.escape(ext.remote or '')}
+ """ + + return code + + +def install_extension_from_url(dirname, url): + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = last_part.replace(".git", "") + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed' + + tmpdir = os.path.join(paths.script_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + + repo = git.Repo.clone_from(url, tmpdir) + repo.remote().fetch() + + os.rename(tmpdir, target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def create_ui(): + import modules.ui + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.TabItem("Installed"): + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False) + + with gr.Row(): + apply = gr.Button(value="Apply and restart UI", variant="primary") + check = gr.Button(value="Check for updates") + + extensions_table = gr.HTML(lambda: extension_table()) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list], + outputs=[], + ) + + check.click( + fn=check_updates, + _js="extensions_check", + inputs=[], + outputs=[extensions_table], + ) + + with gr.TabItem("Install from URL"): + install_url = gr.Text(label="URL for extension's git repository") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + intall_button = gr.Button(value="Install", variant="primary") + intall_result = gr.HTML(elem_id="extension_install_result") + + intall_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + inputs=[install_dirname, install_url], + outputs=[extensions_table, intall_result], + ) + + return ui diff --git a/style.css b/style.css index 8b2211b1..859c3933 100644 --- a/style.css +++ b/style.css @@ -530,6 +530,26 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h min-height: 480px !important; } +/* Extensions */ + +#extensions{ + border-collapse: collapse; +} + +#extensions td, #extensions th{ + border: 1px solid #ccc; + padding: 0.25em 0.5em; +} + +#extensions input[type="checkbox"]{ + margin-right: 0.5em; +} + +#tab_extensions button{ + max-width: 16em; +} + + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running @@ -607,4 +627,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section right: unset; left: 0.5em; } -} +} \ No newline at end of file diff --git a/webui.py b/webui.py index 29530872..ad2eb236 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler +from modules import devices, sd_samplers, upscaler, extensions import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -60,6 +60,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): + extensions.list_extensions() + #for ext in extensions.extensions: + # print(ext.name, ext.path, ext.enabled, ext.remote) + #exit() + if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers modules.scripts.load_scripts() @@ -92,15 +97,18 @@ def create_api(app): api = Api(app, queue_lock) return api + def wait_on_server(demo=None): while 1: time.sleep(0.5) - if demo and getattr(demo, 'do_restart', False): + if shared.state.need_restart: + shared.state.need_restart = False time.sleep(0.5) demo.close() time.sleep(0.5) break + def api_only(): initialize() @@ -132,14 +140,16 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) - if (launch_api): + if launch_api: create_api(app) wait_on_server(demo) sd_samplers.set_samplers() - print('Reloading Custom Scripts') + print('Reloading extensions') + extensions.list_extensions() + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) @@ -148,8 +158,6 @@ def webui(): print('Restarting Gradio') - -task = [] if __name__ == "__main__": if cmd_opts.nowebui: api_only() -- cgit v1.2.3 From 58cc03edd0fe8c7e64297bcfe51111caaafabfd7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:40:47 +0300 Subject: fix scripts I broke with the extension tab changes --- modules/extensions.py | 2 +- webui.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) (limited to 'webui.py') diff --git a/modules/extensions.py b/modules/extensions.py index 8d6ae848..897af96e 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -46,7 +46,7 @@ class Extension: res = [] for filename in sorted(os.listdir(dirpath)): - res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) + res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename))) res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] diff --git a/webui.py b/webui.py index ad2eb236..6ff95dc4 100644 --- a/webui.py +++ b/webui.py @@ -61,9 +61,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): extensions.list_extensions() - #for ext in extensions.extensions: - # print(ext.name, ext.path, ext.enabled, ext.remote) - #exit() if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers -- cgit v1.2.3 From af758e97fa2c4c853042f121af4e974be01e6696 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 1 Nov 2022 04:01:49 -0300 Subject: Unload sd_model before loading the other --- modules/lowvram.py | 21 +++++++++++++-------- modules/processing.py | 3 +++ modules/sd_hijack.py | 4 ++++ modules/sd_models.py | 14 +++++++++++++- webui.py | 2 +- 5 files changed, 34 insertions(+), 10 deletions(-) (limited to 'webui.py') diff --git a/modules/lowvram.py b/modules/lowvram.py index f327c3df..a4652cb1 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): # see below for register_forward_pre_hook; # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is # useless here, and we just replace those methods - def first_stage_model_encode_wrap(self, encoder, x): - send_me_to_gpu(self, None) - return encoder(x) - def first_stage_model_decode_wrap(self, decoder, z): - send_me_to_gpu(self, None) - return decoder(z) + first_stage_model = sd_model.first_stage_model + first_stage_model_encode = sd_model.first_stage_model.encode + first_stage_model_decode = sd_model.first_stage_model.decode + + def first_stage_model_encode_wrap(x): + send_me_to_gpu(first_stage_model, None) + return first_stage_model_encode(x) + + def first_stage_model_decode_wrap(z): + send_me_to_gpu(first_stage_model, None) + return first_stage_model_decode(z) # remove three big modules, cond, first_stage, and unet from the model and then # send the model to GPU. Then put modules back. the modules will be in CPU. @@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): # register hooks for those the first two models sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) - sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) - sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) + sd_model.first_stage_model.encode = first_stage_model_encode_wrap + sd_model.first_stage_model.decode = first_stage_model_decode_wrap parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model if use_medvram: diff --git a/modules/processing.py b/modules/processing.py index b1df4918..57d3a523 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) + p.sd_model = None + p.sampler = None + return res diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0f10828e..bc49d235 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -94,6 +94,10 @@ class StableDiffusionModelHijack: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + self.layers = None + self.circular_enabled = False + self.clip = None + def apply_circular(self, enable): if self.circular_enabled == enable: return diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..90007da3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,6 +1,7 @@ import collections import os.path import sys +import gc from collections import namedtuple import torch import re @@ -220,6 +221,12 @@ def load_model(checkpoint_info=None): if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") + if shared.sd_model: + sd_hijack.model_hijack.undo_hijack(shared.sd_model) + shared.sd_model = None + gc.collect() + devices.torch_gc() + sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): @@ -233,6 +240,7 @@ def load_model(checkpoint_info=None): checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) do_inpainting_hijack() + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -252,14 +260,18 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() + if not sd_model: + sd_model = shared.sd_model + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) return shared.sd_model diff --git a/webui.py b/webui.py index 6ff95dc4..9c393e55 100644 --- a/webui.py +++ b/webui.py @@ -77,7 +77,7 @@ def initialize(): modules.scripts.load_scripts() modules.sd_models.load_model() - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 ++++---- modules/sd_vae.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---- webui.py | 4 +-- 3 files changed, 98 insertions(+), 18 deletions(-) (limited to 'webui.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = shared.sd_model - if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e9239326..78e14e8a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,26 +1,65 @@ import torch import os from collections import namedtuple -from modules import shared, devices +from modules import shared, devices, script_callbacks from modules.paths import models_path import glob + model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) vae_dir = "VAE" vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + default_vae_dict = {"auto": "auto", "None": "None"} default_vae_list = ["auto", "None"] + + default_vae_values = [default_vae_dict[x] for x in default_vae_list] vae_dict = dict(default_vae_dict) vae_list = list(default_vae_list) first_load = True + +base_vae = None +loaded_vae_file = None +checkpoint_info = None + + +def get_base_vae(model): + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + return base_vae + return None + + +def store_base_vae(model): + global base_vae, checkpoint_info + if checkpoint_info != model.sd_checkpoint_info: + base_vae = model.first_stage_model.state_dict().copy() + checkpoint_info = model.sd_checkpoint_info + + +def delete_base_vae(): + global base_vae, checkpoint_info + base_vae = None + checkpoint_info = None + + +def restore_base_vae(model): + global base_vae, checkpoint_info + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + load_vae_dict(model, base_vae) + delete_base_vae() + + def get_filename(filepath): return os.path.splitext(os.path.basename(filepath))[0] + def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} @@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list + def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"): return vae_file -def load_vae(model, vae_file): - global first_load, vae_dict, vae_list + +def load_vae(model, vae_file=None): + global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict_1) + load_vae_dict(model, vae_dict_1) - # If vae used is not in dict, update it - # It will be removed on refresh though - if vae_file is not None: + # If vae used is not in dict, update it + # It will be removed on refresh though vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + loaded_vae_file = vae_file + """ # Save current VAE to VAE settings, maybe? will it work? if save_settings: @@ -124,4 +166,45 @@ def load_vae(model, vae_file): """ first_load = False + + +# don't call this from outside +def load_vae_dict(model, vae_dict_1=None): + if vae_dict_1: + store_base_vae(model) + model.first_stage_model.load_state_dict(vae_dict_1) + else: + restore_base_vae() model.first_stage_model.to(devices.dtype_vae) + + +def reload_vae_weights(sd_model=None, vae_file="auto"): + from modules import lowvram, devices, sd_hijack + + if not sd_model: + sd_model = shared.sd_model + + checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_file = checkpoint_info.filename + vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if loaded_vae_file == vae_file: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + load_vae(sd_model, vae_file) + + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + + print(f"VAE Weights loaded.") + return sd_model diff --git a/webui.py b/webui.py index 7cb4691b..034777a2 100644 --- a/webui.py +++ b/webui.py @@ -81,9 +81,7 @@ def initialize(): modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) - # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, - # so for now this reloads the whole model too - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(force=True)), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From dd2108fdac2ebf943d4ac3563a49202222b88acf Mon Sep 17 00:00:00 2001 From: Maiko Tan Date: Wed, 2 Nov 2022 15:04:35 +0800 Subject: fix: should invoke callback as well in api only mode --- modules/script_callbacks.py | 3 ++- webui.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index da88635b..c28e220e 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,6 +2,7 @@ import sys import traceback from collections import namedtuple import inspect +from typing import Optional from fastapi import FastAPI from gradio import Blocks @@ -62,7 +63,7 @@ def clear_callbacks(): callbacks_image_saved.clear() callbacks_cfg_denoiser.clear() -def app_started_callback(demo: Blocks, app: FastAPI): +def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callbacks_app_started: try: c.callback(demo, app) diff --git a/webui.py b/webui.py index 84e5c1fd..dc4223dc 100644 --- a/webui.py +++ b/webui.py @@ -114,6 +114,8 @@ def api_only(): app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) + modules.script_callbacks.app_started_callback(None, app) + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -- cgit v1.2.3 From 5f0117154382eb0e2547c72630256681673e353b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:07:29 +0300 Subject: shut down gradio's "everything allowed" CORS policy; I checked the main functionality to work with this, but if this breaks some exotic workflow, I'm sorry. --- README.md | 7 ++++--- webui.py | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/README.md b/README.md index 55c050d5..33508f31 100644 --- a/README.md +++ b/README.md @@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https: - Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion -- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. -- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) -- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). +- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. +- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch - xformers - https://github.com/facebookresearch/xformers - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru +- Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/webui.py b/webui.py index 3b21c071..81df09dd 100644 --- a/webui.py +++ b/webui.py @@ -141,6 +141,12 @@ def webui(): # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attcker wants, including installing an extension and + # runnnig its code. We disable this here. Suggested by RyotaK. + app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + app.add_middleware(GZipMiddleware, minimum_size=1000) if launch_api: -- cgit v1.2.3 From b8435e632f7ba0da12a2c8e9c788dda519279d24 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sat, 5 Nov 2022 02:36:47 +0800 Subject: add --cors-allow-origins cmd opt --- modules/shared.py | 7 ++++--- webui.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..e83cbcdf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) cmd_opts = parser.parse_args() restricted_opts = { @@ -147,9 +148,9 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.show_progress_every_n_steps == -1: self.do_set_current_image() - + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -198,7 +199,7 @@ class State: return if self.current_latent is None: return - + if opts.show_progress_grid: self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) else: diff --git a/webui.py b/webui.py index 81df09dd..3788af0b 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,7 @@ import importlib import signal import threading from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path @@ -93,6 +94,11 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) +def setup_cors(app): + if cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + + def create_api(app): from modules.api.api import Api api = Api(app, queue_lock) @@ -114,6 +120,7 @@ def api_only(): initialize() app = FastAPI() + setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) @@ -147,6 +154,8 @@ def webui(): # runnnig its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) if launch_api: -- cgit v1.2.3 From e9a5562b9b27a1a4f9c282637b111cefd9727a41 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:06:51 -0500 Subject: add support for tls (gradio tls options) --- modules/shared.py | 3 +++ webui.py | 22 ++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index 962115f6..7a20c3af 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) cmd_opts = parser.parse_args() restricted_opts = { diff --git a/webui.py b/webui.py index 81df09dd..d366f4ca 100644 --- a/webui.py +++ b/webui.py @@ -34,7 +34,7 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork queue_lock = threading.Lock() - +server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name def wrap_queued_call(func): def f(*args, **kwargs): @@ -85,6 +85,22 @@ def initialize(): shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print(f"path: '{cmd_opts.tls_keyfile}' {type(cmd_opts.tls_keyfile)}") + print(f"path: '{cmd_opts.tls_certfile}' {type(cmd_opts.tls_certfile)}") + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -131,8 +147,10 @@ def webui(): app, local_url, share_url = demo.launch( share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, + server_name=server_name, server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, debug=cmd_opts.gradio_debug, auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, inbrowser=cmd_opts.autolaunch, -- cgit v1.2.3 From a02bad570ef7718436369bb4e4aa5b8e0f1f5689 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:14:21 -0500 Subject: rm dbg --- webui.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index d366f4ca..222dbeee 100644 --- a/webui.py +++ b/webui.py @@ -94,8 +94,6 @@ def initialize(): print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") except TypeError: cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None - print(f"path: '{cmd_opts.tls_keyfile}' {type(cmd_opts.tls_keyfile)}") - print(f"path: '{cmd_opts.tls_certfile}' {type(cmd_opts.tls_certfile)}") print("TLS setup invalid, running webui without TLS") else: print("Running with TLS") -- cgit v1.2.3 From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 09:02:25 +0300 Subject: add ability to create extensions that add localizations --- javascript/ui.js | 2 ++ modules/localization.py | 6 ++++++ modules/scripts.py | 1 - modules/shared.py | 2 -- modules/ui.py | 3 +-- webui.py | 9 +++++---- 6 files changed, 14 insertions(+), 9 deletions(-) (limited to 'webui.py') diff --git a/javascript/ui.js b/javascript/ui.js index 7e116465..95cfd106 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -208,4 +208,6 @@ function update_token_counter(button_id) { function restart_reload(){ document.body.innerHTML='

Reloading...

'; setTimeout(function(){location.reload()},2000) + + return [] } diff --git a/modules/localization.py b/modules/localization.py index b1810cda..f6a6f2fb 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -3,6 +3,7 @@ import os import sys import traceback + localizations = {} @@ -16,6 +17,11 @@ def list_localizations(dirname): localizations[fn] = os.path.join(dirname, file) + from modules import scripts + for file in scripts.list_scripts("localizations", ".json"): + fn, ext = os.path.splitext(file.filename) + localizations[fn] = file.path + def localization_js(current_localization_name): fn = localizations.get(current_localization_name, None) diff --git a/modules/scripts.py b/modules/scripts.py index 366c90d7..637b2329 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,7 +3,6 @@ import sys import traceback from collections import namedtuple -import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing diff --git a/modules/shared.py b/modules/shared.py index 70b998ff..e8bacd3c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -localization.list_localizations(cmd_opts.localizations_dir) - def realesrgan_models_names(): import modules.realesrgan_model diff --git a/modules/ui.py b/modules/ui.py index 76ca9b07..23643c22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call): shared.state.need_restart = True restart_gradio.click( - fn=request_restart, + _js='restart_reload', inputs=[], outputs=[], - _js='restart_reload' ) if column is not None: diff --git a/webui.py b/webui.py index a5a520f0..4342a962 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler, extensions +from modules import devices, sd_samplers, upscaler, extensions, localization import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -28,9 +28,7 @@ import modules.txt2img import modules.script_callbacks import modules.ui -from modules import devices from modules import modelloader -from modules.paths import script_path from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork @@ -64,6 +62,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers @@ -99,7 +98,6 @@ def initialize(): else: print("Running with TLS") - # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -185,6 +183,9 @@ def webui(): print('Reloading extensions') extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') -- cgit v1.2.3 From e5b4e3f820cd09e751f1d168ab05d606d078a0d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 10:12:53 +0300 Subject: add tags to extensions, and ability to filter out tags list changed Settings keys in UI do not print VRAM/etc stats everywhere but in calls that use GPU --- modules/ui.py | 25 ++++++++++++---------- modules/ui_extensions.py | 55 ++++++++++++++++++++++++++++++++++++++---------- style.css | 5 +++++ webui.py | 2 +- 4 files changed, 64 insertions(+), 23 deletions(-) (limited to 'webui.py') diff --git a/modules/ui.py b/modules/ui.py index 23643c22..c946ad59 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None): gr.processing_utils.save_pil_to_file = save_pil_to_file -def wrap_gradio_call(func, extra_outputs=None): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() @@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None): res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_s:.2f}s" - if (elapsed_m > 0): + if elapsed_m > 0: elapsed_text = f"{elapsed_m}m "+elapsed_text if run_memmon: @@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - return tuple(res) return f @@ -1436,7 +1439,7 @@ def create_ui(wrap_gradio_gpu_call): opts.reorder() def run_settings(*args): - changed = 0 + changed = [] for key, value, comp in zip(opts.data_labels.keys(), args, components): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" @@ -1454,12 +1457,12 @@ def create_ui(wrap_gradio_gpu_call): if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - changed += 1 + changed.append(key) try: opts.save(shared.config_filename) except RuntimeError: - return opts.dumpjson(), f'{changed} settings changed without save.' - return opts.dumpjson(), f'{changed} settings changed.' + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 8e0d41d5..02ab9643 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url): +def install_extension_from_index(url, hide_tags): ext_table, message = install_extension_from_url(None, url) - return refresh_available_extensions_from_data(), ext_table, message + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, ext_table, message -def refresh_available_extensions(url): + +def refresh_available_extensions(url, hide_tags): global available_extensions import urllib.request @@ -155,13 +157,25 @@ def refresh_available_extensions(url): available_extensions = json.loads(text) - return url, refresh_available_extensions_from_data(), '' + code, tags = refresh_available_extensions_from_data(hide_tags) + + return url, code, gr.CheckboxGroup.update(choices=tags), '' + + +def refresh_available_extensions_for_tags(hide_tags): + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, '' -def refresh_available_extensions_from_data(): + +def refresh_available_extensions_from_data(hide_tags): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + code = f""" @@ -178,17 +192,24 @@ def refresh_available_extensions_from_data(): name = ext.get("name", "noname") url = ext.get("url", None) description = ext.get("description", "") + extension_tags = ext.get("tags", []) if url is None: continue + if len([x for x in extension_tags if x in tags_to_hide]) > 0: + hidden += 1 + continue + existing = installed_extension_urls.get(normalize_git_url(url), None) install_code = f"""""" + tags_text = ", ".join([f"{x}" for x in extension_tags]) + code += f""" - + @@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
{html.escape(name)}{html.escape(name)}
{tags_text}
{html.escape(description)} {install_code}
""" - return code + if hidden > 0: + code += f"

Extension hidden: {hidden}

" + + return code, list(tags) def create_ui(): @@ -238,21 +262,30 @@ def create_ui(): extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) + install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]), - inputs=[available_extensions_index], - outputs=[available_extensions_index, available_extensions_table, install_result], + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags], + outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install], + inputs=[extension_to_install, hide_tags], outputs=[available_extensions_table, extensions_table, install_result], ) + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags], + outputs=[available_extensions_table, install_result] + ) + with gr.TabItem("Install from URL"): install_url = gr.Text(label="URL for extension's git repository") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") diff --git a/style.css b/style.css index a0382a8c..e2b71f25 100644 --- a/style.css +++ b/style.css @@ -563,6 +563,11 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h opacity: 0.5; } +.extension-tag{ + font-weight: bold; + font-size: 95%; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running diff --git a/webui.py b/webui.py index 4342a962..f4f1d74d 100644 --- a/webui.py +++ b/webui.py @@ -57,7 +57,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return res - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) def initialize(): -- cgit v1.2.3 From a258fd60dbe2d68325339405a2aa72816d06d2fd Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Mon, 7 Nov 2022 00:13:58 -0800 Subject: Add CORS-allow policy launch argument using regex --- modules/shared.py | 7 ++++--- webui.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c..55de286d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,12 +81,13 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") -parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) -parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) diff --git a/webui.py b/webui.py index f4f1d74d..066d94f7 100644 --- a/webui.py +++ b/webui.py @@ -107,8 +107,12 @@ def initialize(): def setup_cors(app): - if cmd_opts.cors_allow_origins: + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + elif cmd_opts.cors_allow_origins: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) def create_api(app): -- cgit v1.2.3 From 3405acc6a4dcef2b73782a04924a9a12422e54f0 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Mon, 14 Nov 2022 14:07:13 -0600 Subject: Give --server-name priority over --listen and add check for --server-name in addition to --share and --listen --- modules/shared.py | 2 +- webui.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..c628b580 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -106,7 +106,7 @@ restricted_opts = { "outdir_save", } -cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) diff --git a/webui.py b/webui.py index f4f1d74d..fc776669 100644 --- a/webui.py +++ b/webui.py @@ -33,7 +33,10 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork queue_lock = threading.Lock() -server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None def wrap_queued_call(func): def f(*args, **kwargs): -- cgit v1.2.3 From 0663706d4405b4f76ce653097f4f8989ee8b8684 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 13:47:03 +0700 Subject: Option to use selected VAE as default fallback instead of primary option --- modules/sd_vae.py | 25 ++++++++++++++++--------- modules/shared.py | 1 + webui.py | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) (limited to 'webui.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..0b5f0213 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -83,7 +83,19 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): return vae_list -def resolve_vae(checkpoint_file, vae_file="auto"): +def get_vae_from_settings(vae_file="auto"): + # else, we load from settings, if not set to be default + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + return vae_file + + +def resolve_vae(checkpoint_file=None, vae_file="auto"): global first_load, vae_dict, vae_list # if vae_file argument is provided, it takes priority, but not saved @@ -98,14 +110,9 @@ def resolve_vae(checkpoint_file, vae_file="auto"): shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") - # else, we load from settings - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print("Selected VAE doesn't exist") + # fallback to selector in settings, if vae selector not set to act as default fallback + if not shared.opts.sd_vae_as_default: + vae_file = get_vae_from_settings(vae_file) # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): diff --git a/modules/shared.py b/modules/shared.py index 17132e42..b84767f0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -336,6 +336,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/webui.py b/webui.py index f4f1d74d..2cd3bae9 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,7 @@ def initialize(): modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From ce6911158b5b2f9cf79b405a1f368f875492044d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 26 Nov 2022 16:10:46 +0300 Subject: Add support Stable Diffusion 2.0 --- README.md | 21 +- launch.py | 12 +- modules/paths.py | 2 +- modules/sd_hijack.py | 297 +++--------------------- modules/sd_hijack_clip.py | 301 +++++++++++++++++++++++++ modules/sd_hijack_inpainting.py | 20 +- modules/sd_hijack_open_clip.py | 37 +++ modules/sd_samplers.py | 14 +- modules/shared.py | 34 ++- modules/textual_inversion/textual_inversion.py | 7 +- modules/ui.py | 13 +- requirements.txt | 1 + requirements_versions.txt | 1 + v1-inference.yaml | 70 ++++++ webui.py | 5 +- 15 files changed, 504 insertions(+), 331 deletions(-) create mode 100644 modules/sd_hijack_clip.py create mode 100644 modules/sd_hijack_open_clip.py create mode 100644 v1-inference.yaml (limited to 'webui.py') diff --git a/README.md b/README.md index 5f5ab3aa..8a4ffade 100644 --- a/README.md +++ b/README.md @@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - -## Where are Aesthetic Gradients?!?! -Aesthetic Gradients are now an extension. You can install it using git: - -```commandline -git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients -``` - -After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart -the UI. The interface for Aesthetic Gradients should appear exactly the same as it was. - -## Where is History/Image browser?!?! -Image browser is now an extension. You can install it using git: - -```commandline -git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser -``` - -After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart -the UI. The interface for Image browser should appear exactly the same as it was. +- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/launch.py b/launch.py index d2f1055c..b1626cb5 100644 --- a/launch.py +++ b/launch.py @@ -134,18 +134,19 @@ def prepare_enviroment(): gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") + openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") + stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -179,6 +180,9 @@ def prepare_enviroment(): if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") + if not is_installed("open_clip"): + run_pip(f"install {openclip_package}", "open_clip") + if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): @@ -196,7 +200,7 @@ def prepare_enviroment(): os.makedirs(dir_repos, exist_ok=True) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 1e7a2fbc..4dd03a35 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -9,7 +9,7 @@ sys.path.insert(0, script_path) # search for directory of stable diffusion in following places sd_path = None -possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] +possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] for possible_sd_path in possible_sd_paths: if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): sd_path = os.path.abspath(possible_sd_path) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..d5243fd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,18 +9,29 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import cmd_opts +from modules import sd_hijack_clip, sd_hijack_open_clip + from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.models.diffusion.ddim import ldm.models.diffusion.plms +import ldm.modules.encoders.modules attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# new memory efficient cross attention blocks do not support hypernets and we already +# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention + +# silence new console spam from SD2 +ldm.modules.attention.print = lambda *args: None +ldm.modules.diffusionmodules.model.print = lambda *args: None def apply_optimizations(): undo_optimizations() @@ -49,16 +60,11 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetworks import hypernetwork - - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 - class StableDiffusionModelHijack: fixes = None @@ -70,10 +76,13 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: + m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model @@ -89,12 +98,15 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords: + if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: - model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: + model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords: + m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped + m.cond_stage_model = m.cond_stage_model.wrapped self.apply_circular(False) self.layers = None @@ -114,261 +126,8 @@ class StableDiffusionModelHijack: def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) - - -class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): - def __init__(self, wrapped, hijack): - super().__init__() - self.wrapped = wrapped - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - self.token_mults = {} - - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] - for text, ident in tokens_with_parens: - mult = 1.0 - for c in text: - if c == '[': - mult /= 1.1 - if c == ']': - mult *= 1.1 - if c == '(': - mult *= 1.1 - if c == ')': - mult /= 1.1 - - if mult != 1.0: - self.token_mults[ident] = mult - - def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_end = self.wrapped.tokenizer.eos_token_id - - if opts.enable_emphasis: - parsed = prompt_parser.parse_prompt_attention(line) - else: - parsed = [[line, 1.0]] - - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] - - fixes = [] - remade_tokens = [] - multipliers = [] - last_comma = -1 - - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] - - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) - - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - - if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_multipliers = [] - for line in texts: - if line in cache: - remade_tokens, fixes, multipliers = cache[line] - else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) - token_count = max(current_token_count, token_count) - - cache[line] = (remade_tokens, fixes, multipliers) - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def process_text_old(self, text): - id_start = self.wrapped.tokenizer.bos_token_id - id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - overflowing_words = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) - - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 - - return z - - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - - tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - - if opts.CLIP_stop_at_last_layers > 1: - z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] - z = self.wrapped.transformer.text_model.final_layer_norm(z) - else: - z = outputs.last_hidden_state - - # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) - original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) - new_mean = z.mean() - z *= original_mean / new_mean + return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) - return z class EmbeddingsWithFixes(torch.nn.Module): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py new file mode 100644 index 00000000..b451d1cf --- /dev/null +++ b/modules/sd_hijack_clip.py @@ -0,0 +1,301 @@ +import math + +import torch + +from modules import prompt_parser, devices +from modules.shared import opts + + +def get_target_prompt_token_count(token_count): + return math.ceil(max(token_count, 1) / 75) * 75 + + +class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + def __init__(self, wrapped, hijack): + super().__init__() + self.wrapped = wrapped + self.hijack = hijack + + def tokenize(self, texts): + raise NotImplementedError + + def encode_with_transformers(self, tokens): + raise NotImplementedError + + def encode_embedding_init_text(self, init_text, nvpt): + raise NotImplementedError + + def tokenize_line(self, line, used_custom_terms, hijack_comments): + if opts.enable_emphasis: + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + fixes = [] + remade_tokens = [] + multipliers = [] + last_comma = -1 + + for tokens, (text, weight) in zip(tokenized, parsed): + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [self.id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + + if embedding is None: + remade_tokens.append(token) + multipliers.append(weight) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [self.id_end] * rem + multipliers += [1.0] * rem + iteration += 1 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + token_count = len(remade_tokens) + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + + remade_tokens = remade_tokens + [self.id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add + + return remade_tokens, fixes, multipliers, token_count + + def process_text(self, texts): + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_multipliers = [] + for line in texts: + if line in cache: + remade_tokens, fixes, multipliers = cache[line] + else: + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) + + cache[line] = (remade_tokens, fixes, multipliers) + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def process_text_old(self, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + use_old = opts.use_old_emphasis_implementation + if use_old: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + else: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + tokens = [] + multipliers = [] + for j in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[j]) > 0: + tokens.append(remade_batch_tokens[j][:75]) + multipliers.append(batch_multipliers[j][:75]) + else: + tokens.append([self.id_end] * 75) + multipliers.append([1.0] * 75) + + z1 = self.process_tokens(tokens, multipliers) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(devices.device) + + if self.id_end != self.id_pad: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(self.id_end) + tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad + + z = self.encode_with_transformers(tokens) + + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + original_mean = z.mean() + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + new_mean = z.mean() + z *= original_mean / new_mean + + return z + + +class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + self.tokenizer = wrapped.tokenizer + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.transformer.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 46714a4f..938f9a58 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -199,8 +199,8 @@ def sample_plms(self, @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): b, *_, device = *x.shape, x.device def get_model_output(x, t): @@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature @@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): - ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + # most of this stuff seems to no longer be needed because it is already included into SD2.0 + # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint + # p_sample_plms is needed because PLMS can't work with dicts as conditionings + # this file should be cleaned up later if weverything tuens out to work fine + + # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion - ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms - ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms - + # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py new file mode 100644 index 00000000..f733e852 --- /dev/null +++ b/modules/sd_hijack_open_clip.py @@ -0,0 +1,37 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + +tokenizer = open_clip.tokenizer._tokenizer + + +class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers + z = self.wrapped.encode_with_transformer(tokens) + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4fe67854..4edd8c60 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -127,7 +127,8 @@ class InterruptedException(BaseException): class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) - self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms + self.is_plms = hasattr(self.sampler, 'p_sample_plms') + self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim self.mask = None self.nmask = None self.init_latent = None @@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def adjust_steps_if_invalid(self, p, num_steps): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) @@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler: return num_steps - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) steps = self.adjust_steps_if_invalid(p, steps) @@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler: steps = self.adjust_steps_if_invalid(p, steps or p.steps) # Wrap the conditioning models with additional image conditioning for inpainting model + # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -350,7 +350,9 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) diff --git a/modules/shared.py b/modules/shared.py index c93ae2a3..8fb1387a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,17 +11,15 @@ import tqdm import modules.artists import modules.interrogate import modules.memmon -import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading -from modules.hypernetworks import hypernetwork +from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -121,10 +119,12 @@ xformers_available = False config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) -hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) +hypernetworks = {} loaded_hypernetwork = None + def reload_hypernetworks(): + from modules.hypernetworks import hypernetwork global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) @@ -206,10 +206,11 @@ class State: if self.current_latent is None: return + import modules.sd_samplers if opts.show_progress_grid: - self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) else: - self.current_image = sd_samplers.sample_to_image(self.current_latent) + self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) self.current_image_sampling_step = self.sampling_step @@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict): return options_dict +def list_checkpoint_tiles(): + import modules.sd_models + return modules.sd_models.checkpoint_tiles() + + +def refresh_checkpoints(): + import modules.sd_models + return modules.sd_models.list_models() + + +def list_samplers(): + import modules.sd_samplers + return modules.sd_samplers.all_samplers + + hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} options_templates = {} @@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), @@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5e4d8688..a273e663 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -64,7 +64,8 @@ class EmbeddingDatabase: self.word_embeddings[embedding.name] = embedding - ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working + ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] if first_id not in self.ids_lookup: @@ -155,13 +156,11 @@ class EmbeddingDatabase: def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model - embedding_layer = cond_model.wrapped.transformer.text_model.embeddings with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active - ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): diff --git a/modules/ui.py b/modules/ui.py index e6da1b2a..e5cb69d0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -478,9 +478,7 @@ def create_toprow(is_img2img): if is_img2img: with gr.Column(scale=1, elem_id="interrogate_col"): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - - if cmd_opts.deepdanbooru: - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): with gr.Row(): @@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - if cmd_opts.deepdanbooru: - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], ) diff --git a/requirements.txt b/requirements.txt index 762db4f3..e4e5ec64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ kornia lark inflection GitPython +torchsde diff --git a/requirements_versions.txt b/requirements_versions.txt index 662ca684..8d557fe3 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -25,3 +25,4 @@ kornia==0.6.7 lark==1.1.2 inflection==0.5.1 GitPython==3.1.27 +torchsde==0.2.5 diff --git a/v1-inference.yaml b/v1-inference.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/webui.py b/webui.py index c5e5fe75..23215d1e 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler, extensions, localization +from modules import shared, devices, sd_samplers, upscaler, extensions, localization import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -23,7 +23,6 @@ import modules.scripts import modules.sd_hijack import modules.sd_models import modules.sd_vae -import modules.shared as shared import modules.txt2img import modules.script_callbacks @@ -86,7 +85,7 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: -- cgit v1.2.3 From b006382784a2f0887317bb60ea49d19b50a5dc7e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 11:52:53 +0300 Subject: serve images from where they are saved instead of a temporary directory add an option to choose a different temporary directory in the UI add an option to cleanup the selected temporary directory at startup --- modules/images.py | 2 ++ modules/shared.py | 7 ++++++ modules/ui.py | 16 ------------- modules/ui_tempdir.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++ webui.py | 16 ++++++++----- 5 files changed, 82 insertions(+), 21 deletions(-) create mode 100644 modules/ui_tempdir.py (limited to 'webui.py') diff --git a/modules/images.py b/modules/images.py index 26d5b7a9..8737ccff 100644 --- a/modules/images.py +++ b/modules/images.py @@ -524,6 +524,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: image.save(fullfn, quality=opts.jpeg_quality) + image.already_saved_as = fullfn + target_side_length = 4000 oversize = image.width > target_side_length or image.height > target_side_length if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024): diff --git a/modules/shared.py b/modules/shared.py index 8fb1387a..af975f54 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,6 +16,9 @@ import modules.devices as devices from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path + +demo = None + sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() @@ -292,6 +295,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { diff --git a/modules/ui.py b/modules/ui.py index c8b8fecd..ea925c40 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -157,22 +157,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def save_pil_to_file(pil_image, dir=None): - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in pil_image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True - - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) - return file_obj - - -# override save to file function so that it also writes PNG info -gr.processing_utils.save_pil_to_file = save_pil_to_file - def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py new file mode 100644 index 00000000..9c6d3a9d --- /dev/null +++ b/modules/ui_tempdir.py @@ -0,0 +1,62 @@ +import os +import tempfile +from collections import namedtuple + +import gradio as gr + +from PIL import PngImagePlugin + +from modules import shared + + +Savedfile = namedtuple("Savedfile", ["name"]) + + +def save_pil_to_file(pil_image, dir=None): + already_saved_as = getattr(pil_image, 'already_saved_as', None) + if already_saved_as: + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))} + file_obj = Savedfile(already_saved_as) + return file_obj + + if shared.opts.temp_dir != "": + dir = shared.opts.temp_dir + + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def on_tmpdir_changed(): + if shared.opts.temp_dir == "" or shared.demo is None: + return + + os.makedirs(shared.opts.temp_dir, exist_ok=True) + + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)} + + +def cleanup_tmpdr(): + temp_dir = shared.opts.temp_dir + if temp_dir == "" or not os.path.isdir(temp_dir): + return + + for root, dirs, files in os.walk(temp_dir, topdown=False): + for name in files: + _, extension = os.path.splitext(name) + if extension != ".png": + continue + + filename = os.path.join(root, name) + os.remove(filename) diff --git a/webui.py b/webui.py index 23215d1e..6b79dc55 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import shared, devices, sd_samplers, upscaler, extensions, localization +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -31,12 +31,14 @@ from modules import modelloader from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork + queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None + def wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: @@ -87,6 +89,7 @@ def initialize(): shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -149,9 +152,12 @@ def webui(): initialize() while 1: - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - app, local_url, share_url = demo.launch( + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, @@ -178,9 +184,9 @@ def webui(): if launch_api: create_api(app) - modules.script_callbacks.app_started_callback(demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) - wait_on_server(demo) + wait_on_server(shared.demo) sd_samplers.set_samplers() -- cgit v1.2.3 From 0b5dcb3d7ce397ad38312dbfc70febe7bb42dcc3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 28 Nov 2022 09:00:10 +0300 Subject: fix an error that happens when you type into prompt while switching model, put queue stuff into separate file --- modules/call_queue.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++ modules/ui.py | 67 ++--------------------------------- webui.py | 30 ++-------------- 3 files changed, 104 insertions(+), 91 deletions(-) create mode 100644 modules/call_queue.py (limited to 'webui.py') diff --git a/modules/call_queue.py b/modules/call_queue.py new file mode 100644 index 00000000..4cd49533 --- /dev/null +++ b/modules/call_queue.py @@ -0,0 +1,98 @@ +import html +import sys +import threading +import traceback +import time + +from modules import shared + +queue_lock = threading.Lock() + + +def wrap_queued_call(func): + def f(*args, **kwargs): + with queue_lock: + res = func(*args, **kwargs) + + return res + + return f + + +def wrap_gradio_gpu_call(func, extra_outputs=None): + def f(*args, **kwargs): + + shared.state.begin() + + with queue_lock: + res = func(*args, **kwargs) + + shared.state.end() + + return res + + return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) + + +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): + run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats + if run_memmon: + shared.mem_mon.monitor() + t = time.perf_counter() + + try: + res = list(func(*args, **kwargs)) + except Exception as e: + # When printing out our debug argument list, do not print out more than a MB of text + max_debug_str_len = 131072 # (1024*1024)/8 + + print("Error completing request", file=sys.stderr) + argStr = f"Arguments: {str(args)} {str(kwargs)}" + print(argStr[:max_debug_str_len], file=sys.stderr) + if len(argStr) > max_debug_str_len: + print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) + + print(traceback.format_exc(), file=sys.stderr) + + shared.state.job = "" + shared.state.job_count = 0 + + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"] + + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if elapsed_m > 0: + elapsed_text = f"{elapsed_m}m "+elapsed_text + + if run_memmon: + mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + + vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + else: + vram_html = '' + + # last item is always HTML + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + + return tuple(res) + + return f + diff --git a/modules/ui.py b/modules/ui.py index 446bee40..00809361 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -17,7 +17,7 @@ import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin - +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules.paths import script_path @@ -158,67 +158,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats - if run_memmon: - shared.mem_mon.monitor() - t = time.perf_counter() - - try: - res = list(func(*args, **kwargs)) - except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) - - shared.state.job = "" - shared.state.job_count = 0 - - if extra_outputs_array is None: - extra_outputs_array = [None, ''] - - res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] - - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - - if not add_stats: - return tuple(res) - - elapsed = time.perf_counter() - t - elapsed_m = int(elapsed // 60) - elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" - if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text - - if run_memmon: - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} - active_peak = mem_stats['active_peak'] - reserved_peak = mem_stats['reserved_peak'] - sys_peak = mem_stats['system_peak'] - sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) - - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" - else: - vram_html = '' - - # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - - return tuple(res) - - return f def calc_time_left(progress, threshold, label, force_display): @@ -666,7 +605,7 @@ Requested path was: {f} return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info -def create_ui(wrap_gradio_gpu_call): +def create_ui(): import modules.img2img import modules.txt2img @@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call): height, ] - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) diff --git a/webui.py b/webui.py index 7a56bde8..16e7ec1a 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir @@ -32,38 +33,12 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork -queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None -def wrap_queued_call(func): - def f(*args, **kwargs): - with queue_lock: - res = func(*args, **kwargs) - - return res - - return f - - -def wrap_gradio_gpu_call(func, extra_outputs=None): - def f(*args, **kwargs): - - shared.state.begin() - - with queue_lock: - res = func(*args, **kwargs) - - shared.state.end() - - return res - - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) - - def initialize(): extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) @@ -159,7 +134,7 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() - shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, @@ -189,6 +164,7 @@ def webui(): create_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) -- cgit v1.2.3 From b6e5edd74657e3fd1fbd04f341b7a84625d4aa7a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Dec 2022 18:06:33 +0300 Subject: add built-in extension system add support for adding upscalers in extensions move LDSR, ScuNET and SwinIR to built-in extensions --- extensions-builtin/LDSR/ldsr_model_arch.py | 230 +++++ extensions-builtin/LDSR/preload.py | 6 + extensions-builtin/LDSR/scripts/ldsr_model.py | 63 ++ extensions-builtin/ScuNET/preload.py | 6 + extensions-builtin/ScuNET/scripts/scunet_model.py | 87 ++ extensions-builtin/ScuNET/scunet_model_arch.py | 265 ++++++ extensions-builtin/SwinIR/preload.py | 6 + extensions-builtin/SwinIR/scripts/swinir_model.py | 168 ++++ extensions-builtin/SwinIR/swinir_model_arch.py | 867 ++++++++++++++++++ extensions-builtin/SwinIR/swinir_model_arch_v2.py | 1017 +++++++++++++++++++++ modules/devices.py | 11 +- modules/extensions.py | 22 +- modules/ldsr_model.py | 54 -- modules/ldsr_model_arch.py | 230 ----- modules/modelloader.py | 20 +- modules/scunet_model.py | 87 -- modules/scunet_model_arch.py | 265 ------ modules/shared.py | 13 +- modules/swinir_model.py | 157 ---- modules/swinir_model_arch.py | 867 ------------------ modules/swinir_model_arch_v2.py | 1017 --------------------- modules/ui.py | 1 - modules/ui_extensions.py | 8 +- webui.py | 5 +- 24 files changed, 2761 insertions(+), 2711 deletions(-) create mode 100644 extensions-builtin/LDSR/ldsr_model_arch.py create mode 100644 extensions-builtin/LDSR/preload.py create mode 100644 extensions-builtin/LDSR/scripts/ldsr_model.py create mode 100644 extensions-builtin/ScuNET/preload.py create mode 100644 extensions-builtin/ScuNET/scripts/scunet_model.py create mode 100644 extensions-builtin/ScuNET/scunet_model_arch.py create mode 100644 extensions-builtin/SwinIR/preload.py create mode 100644 extensions-builtin/SwinIR/scripts/swinir_model.py create mode 100644 extensions-builtin/SwinIR/swinir_model_arch.py create mode 100644 extensions-builtin/SwinIR/swinir_model_arch_v2.py delete mode 100644 modules/ldsr_model.py delete mode 100644 modules/ldsr_model_arch.py delete mode 100644 modules/scunet_model.py delete mode 100644 modules/scunet_model_arch.py delete mode 100644 modules/swinir_model.py delete mode 100644 modules/swinir_model_arch.py delete mode 100644 modules/swinir_model_arch_v2.py (limited to 'webui.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py new file mode 100644 index 00000000..90e0a2f0 --- /dev/null +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -0,0 +1,230 @@ +import gc +import time +import warnings + +import numpy as np +import torch +import torchvision +from PIL import Image +from einops import rearrange, repeat +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config, ismap + +warnings.filterwarnings("ignore", category=UserWarning) + + +# Create LDSR Class +class LDSR: + def load_model_from_config(self, half_attention): + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + model = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model.cuda() + if half_attention: + model = model.half() + + model.eval() + return {"model": model} + + def __init__(self, model_path, yaml_path): + self.modelPath = model_path + self.yamlPath = yaml_path + + @staticmethod + def run(model, selected_path, custom_steps, eta): + example = get_cond(selected_path) + + n_runs = 1 + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = 1. + eta = eta + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + x_t = None + logs = None + for n in range(n_runs): + if custom_shape is not None: + x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + custom_steps=custom_steps, + eta=eta, quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, + ddim_use_x0_pred=ddim_use_x0_pred + ) + return logs + + def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): + model = self.load_model_from_config(half_attention) + + # Run settings + diffusion_steps = int(steps) + eta = 1.0 + + down_sample_method = 'Lanczos' + + gc.collect() + torch.cuda.empty_cache() + + im_og = image + width_og, height_og = im_og.size + # If we can adjust the max upscale size, then the 4 below should be our variable + down_sample_rate = target_scale / 4 + wd = width_og * down_sample_rate + hd = height_og * down_sample_rate + width_downsampled_pre = int(np.ceil(wd)) + height_downsampled_pre = int(np.ceil(hd)) + + if down_sample_rate != 1: + print( + f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') + im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) + else: + print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") + + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts + pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size + im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) + + logs = self.run(model["model"], im_padded, diffusion_steps, eta) + + sample = logs["sample"] + sample = sample.detach().cpu() + sample = torch.clamp(sample, -1., 1.) + sample = (sample + 1.) / 2. * 255 + sample = sample.numpy().astype(np.uint8) + sample = np.transpose(sample, (0, 2, 3, 1)) + a = Image.fromarray(sample[0]) + + # remove padding + a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) + + del model + gc.collect() + torch.cuda.empty_cache() + return a + + +def get_cond(selected_path): + example = dict() + up_f = 4 + c = selected_path.convert('RGB') + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], + antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. + + c = c.to(torch.device("cuda")) + example["LR_image"] = c + example["image"] = c_up + + return example + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, + mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, + corrector_kwargs=None, x_t=None + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, + normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, + mask=mask, x0=x0, temperature=temperature, verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, x_t=x_t) + + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, + corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): + log = dict() + + z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, 'split_input_params') + and model.cond_stage_key == 'coordinates_bbox'), + return_original_cond=True) + + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + z0 = None + + log["input"] = x + log["reconstruction"] = xrec + + if ismap(xc): + log["original_conditioning"] = model.to_rgb(xc) + if hasattr(model, 'cond_stage_key'): + log[model.cond_stage_key] = model.to_rgb(xc) + + else: + log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_model: + log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_key == 'class_label': + log[model.cond_stage_key] = xc[model.cond_stage_key] + + with model.ema_scope("Plotting"): + t0 = time.time() + + sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, mask=None, x0=z0, + temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, + x_t=x_T) + t1 = time.time() + + if ddim_use_x0_pred: + sample = intermediates['pred_x0'][-1] + + x_sample = model.decode_first_stage(sample) + + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + log["sample"] = x_sample + log["time"] = t1 - t0 + + return log diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py new file mode 100644 index 00000000..d746007c --- /dev/null +++ b/extensions-builtin/LDSR/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py new file mode 100644 index 00000000..841ecba0 --- /dev/null +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -0,0 +1,63 @@ +import os +import sys +import traceback + +from basicsr.utils.download_util import load_file_from_url + +from modules.upscaler import Upscaler, UpscalerData +from ldsr_model_arch import LDSR +from modules import shared, script_callbacks + + +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): + self.name = "LDSR" + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + # Remove incorrect project.yaml file if too big + yaml_path = os.path.join(self.model_path, "project.yaml") + old_model_path = os.path.join(self.model_path, "model.pth") + new_model_path = os.path.join(self.model_path, "model.ckpt") + if os.path.exists(yaml_path): + statinfo = os.stat(yaml_path) + if statinfo.st_size >= 10485760: + print("Removing invalid LDSR YAML file.") + os.remove(yaml_path) + if os.path.exists(old_model_path): + print("Renaming model from model.pth to model.ckpt") + os.rename(old_model_path, new_model_path) + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.ckpt", progress=True) + yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + return ldsr.super_resolution(img, ddim_steps, self.scale) + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py new file mode 100644 index 00000000..f12c5b90 --- /dev/null +++ b/extensions-builtin/ScuNET/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET')) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py new file mode 100644 index 00000000..e0fbf3a3 --- /dev/null +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -0,0 +1,87 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import devices, modelloader +from scunet_model_arch import SCUNet as net + + +class UpscalerScuNET(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "ScuNET" + self.model_name = "ScuNET GAN" + self.model_name2 = "ScuNET PSNR" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" + self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pth"]) + scalers = [] + add_model2 = True + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + if name == self.model_name2 or file == self.model_url2: + add_model2 = False + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading ScuNET model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + if add_model2: + scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) + scalers.append(scaler_data2) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + + model = self.load_model(selected_file) + if model is None: + return img + + device = devices.get_device_for('scunet') + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device) + + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + device = devices.get_device_for('scunet') + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: + print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) + return None + + model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + + return model + diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py new file mode 100644 index 00000000..43ca8d36 --- /dev/null +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, + 2).transpose( + 0, 1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting square. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) + h_windows = x.size(1) + w_windows = x.size(2) + # square validation + # assert h_windows == w_windows + + x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) + qkv = self.embedding_layer(x) + q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) + sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') + # Using Attn Mask to distinguish different subwindows. + if self.type != 'W': + attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) + output = rearrange(output, 'h b w p c -> b w p (h c)') + output = self.linear(output) + output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) + + if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), + dims=(1, 2)) + return output + + def relative_embedding(self): + cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 + # negative is allowed + return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] + + +class Block(nn.Module): + def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer Block + """ + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ['W', 'SW'] + self.type = type + if input_resolution <= window_size: + self.type = 'W' + + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer and Conv Block + """ + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ['W', 'SW'] + if self.input_resolution <= self.window_size: + self.type = 'W' + + self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, + self.type, self.input_resolution) + self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) + ) + + def forward(self, x): + conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange('b c h w -> b h w c')(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange('b h w c -> b c h w')(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): + def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): + super(SCUNet, self).__init__() + if config is None: + config = [2, 2, 2, 2, 2, 2, 2] + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[0])] + \ + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[1])] + \ + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[2])] + \ + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 8) + for i in range(config[3])] + + begin += config[3] + self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[4])] + + begin += config[4] + self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[5])] + + begin += config[5] + self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[6])] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + # self.apply(self._init_weights) + + def forward(self, x0): + + h, w = x0.size()[-2:] + paddingBottom = int(np.ceil(h / 64) * 64 - h) + paddingRight = int(np.ceil(w / 64) * 64 - w) + x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) + + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) \ No newline at end of file diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py new file mode 100644 index 00000000..567e44bc --- /dev/null +++ b/extensions-builtin/SwinIR/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR')) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py new file mode 100644 index 00000000..782769e2 --- /dev/null +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -0,0 +1,168 @@ +import contextlib +import os + +import numpy as np +import torch +from PIL import Image +from basicsr.utils.download_util import load_file_from_url +from tqdm import tqdm + +from modules import modelloader, devices, script_callbacks, shared +from modules.shared import cmd_opts, opts +from swinir_model_arch import SwinIR as net +from swinir_model_arch_v2 import Swin2SR as net2 +from modules.upscaler import Upscaler, UpscalerData + + +device_swinir = devices.get_device_for('swinir') + + +class UpscalerSwinIR(Upscaler): + def __init__(self, dirname): + self.name = "SwinIR" + self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ + "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ + "-L_x4_GAN.pth " + self.model_name = "SwinIR 4x" + self.user_path = dirname + super().__init__() + scalers = [] + model_files = self.find_models(ext_filter=[".pt", ".pth"]) + for model in model_files: + if "http" in model: + name = self.model_name + else: + name = modelloader.friendly_name(model) + model_data = UpscalerData(name, model, self) + scalers.append(model_data) + self.scalers = scalers + + def do_upscale(self, img, model_file): + model = self.load_model(model_file) + if model is None: + return img + model = model.to(device_swinir, dtype=devices.dtype) + img = upscale(img, model) + try: + torch.cuda.empty_cache() + except: + pass + return img + + def load_model(self, path, scale=4): + if "http" in path: + dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None + if filename.endswith(".v2.pth"): + model = net2( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + params = "params_ema" + + pretrained_model = torch.load(filename) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) + return model + + +def upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + window_size=8, + scale=4, +): + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + with torch.no_grad(), devices.autocast(): + _, _, h_old, w_old = img.size() + h_pad = (h_old // window_size + 1) * window_size - h_old + w_pad = (w_old // window_size + 1) * window_size - w_old + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] + output = inference(img, model, tile, tile_overlap, window_size, scale) + output = output[..., : h_old * scale, : w_old * scale] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose( + output[[2, 1, 0], :, :], (1, 2, 0) + ) # CHW-RGB to HCW-BGR + output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 + return Image.fromarray(output, "RGB") + + +def inference(img, model, tile, tile_overlap, window_size, scale): + # test the image tile by tile + b, c, h, w = img.size() + tile = min(tile, h, w) + assert tile % window_size == 0, "tile size should be a multiple of window_size" + sf = scale + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) + + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) + output = E.div_(W) + + return output + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py new file mode 100644 index 00000000..863f42db --- /dev/null +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -0,0 +1,867 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py new file mode 100644 index 00000000..0e28ae6e --- /dev/null +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -0,0 +1,1017 @@ +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ +# Written by Conde and Choi et al. +# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + #assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(Swin2SR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_hf: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] + + else: + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/modules/devices.py b/modules/devices.py index d6a76844..f8cffae1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -44,6 +44,15 @@ def get_optimal_device(): return cpu +def get_device_for(task): + from modules import shared + + if task in shared.cmd_opts.use_cpu: + return cpu + + return get_optimal_device() + + def torch_gc(): if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): @@ -67,7 +76,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None +device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/extensions.py b/modules/extensions.py index db9c4200..b522125c 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -8,6 +8,7 @@ from modules import paths, shared extensions = [] extensions_dir = os.path.join(paths.script_path, "extensions") +extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") def active(): @@ -15,12 +16,13 @@ def active(): class Extension: - def __init__(self, name, path, enabled=True): + def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name self.path = path self.enabled = enabled self.status = '' self.can_update = False + self.is_builtin = is_builtin repo = None try: @@ -79,11 +81,19 @@ def list_extensions(): if not os.path.isdir(extensions_dir): return - for dirname in sorted(os.listdir(extensions_dir)): - path = os.path.join(extensions_dir, dirname) - if not os.path.isdir(path): - continue + paths = [] + for dirname in [extensions_dir, extensions_builtin_dir]: + if not os.path.isdir(dirname): + return - extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) + for extension_dirname in sorted(os.listdir(dirname)): + path = os.path.join(dirname, extension_dirname) + if not os.path.isdir(path): + continue + + paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) + + for dirname, path, is_builtin in paths: + extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) extensions.append(extension) diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py deleted file mode 100644 index 8c4db44a..00000000 --- a/modules/ldsr_model.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import sys -import traceback - -from basicsr.utils.download_util import load_file_from_url - -from modules.upscaler import Upscaler, UpscalerData -from modules.ldsr_model_arch import LDSR -from modules import shared - - -class UpscalerLDSR(Upscaler): - def __init__(self, user_path): - self.name = "LDSR" - self.user_path = user_path - self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" - super().__init__() - scaler_data = UpscalerData("LDSR", None, self) - self.scalers = [scaler_data] - - def load_model(self, path: str): - # Remove incorrect project.yaml file if too big - yaml_path = os.path.join(self.model_path, "project.yaml") - old_model_path = os.path.join(self.model_path, "model.pth") - new_model_path = os.path.join(self.model_path, "model.ckpt") - if os.path.exists(yaml_path): - statinfo = os.stat(yaml_path) - if statinfo.st_size >= 10485760: - print("Removing invalid LDSR YAML file.") - os.remove(yaml_path) - if os.path.exists(old_model_path): - print("Renaming model from model.pth to model.ckpt") - os.rename(old_model_path, new_model_path) - model = load_file_from_url(url=self.model_url, model_dir=self.model_path, - file_name="model.ckpt", progress=True) - yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, - file_name="project.yaml", progress=True) - - try: - return LDSR(model, yaml) - - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - return None - - def do_upscale(self, img, path): - ldsr = self.load_model(path) - if ldsr is None: - print("NO LDSR!") - return img - ddim_steps = shared.opts.ldsr_steps - return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py deleted file mode 100644 index 90e0a2f0..00000000 --- a/modules/ldsr_model_arch.py +++ /dev/null @@ -1,230 +0,0 @@ -import gc -import time -import warnings - -import numpy as np -import torch -import torchvision -from PIL import Image -from einops import rearrange, repeat -from omegaconf import OmegaConf - -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config, ismap - -warnings.filterwarnings("ignore", category=UserWarning) - - -# Create LDSR Class -class LDSR: - def load_model_from_config(self, half_attention): - print(f"Loading model from {self.modelPath}") - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] - config = OmegaConf.load(self.yamlPath) - model = instantiate_from_config(config.model) - model.load_state_dict(sd, strict=False) - model.cuda() - if half_attention: - model = model.half() - - model.eval() - return {"model": model} - - def __init__(self, model_path, yaml_path): - self.modelPath = model_path - self.yamlPath = yaml_path - - @staticmethod - def run(model, selected_path, custom_steps, eta): - example = get_cond(selected_path) - - n_runs = 1 - guider = None - ckwargs = None - ddim_use_x0_pred = False - temperature = 1. - eta = eta - custom_shape = None - - height, width = example["image"].shape[1:3] - split_input = height >= 128 and width >= 128 - - if split_input: - ks = 128 - stride = 64 - vqf = 4 # - model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), - "vqf": vqf, - "patch_distributed_vq": True, - "tie_braker": False, - "clip_max_weight": 0.5, - "clip_min_weight": 0.01, - "clip_max_tie_weight": 0.5, - "clip_min_tie_weight": 0.01} - else: - if hasattr(model, "split_input_params"): - delattr(model, "split_input_params") - - x_t = None - logs = None - for n in range(n_runs): - if custom_shape is not None: - x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) - x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) - - logs = make_convolutional_sample(example, model, - custom_steps=custom_steps, - eta=eta, quantize_x0=False, - custom_shape=custom_shape, - temperature=temperature, noise_dropout=0., - corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, - ddim_use_x0_pred=ddim_use_x0_pred - ) - return logs - - def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): - model = self.load_model_from_config(half_attention) - - # Run settings - diffusion_steps = int(steps) - eta = 1.0 - - down_sample_method = 'Lanczos' - - gc.collect() - torch.cuda.empty_cache() - - im_og = image - width_og, height_og = im_og.size - # If we can adjust the max upscale size, then the 4 below should be our variable - down_sample_rate = target_scale / 4 - wd = width_og * down_sample_rate - hd = height_og * down_sample_rate - width_downsampled_pre = int(np.ceil(wd)) - height_downsampled_pre = int(np.ceil(hd)) - - if down_sample_rate != 1: - print( - f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') - im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) - else: - print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - - # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts - pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size - im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - - logs = self.run(model["model"], im_padded, diffusion_steps, eta) - - sample = logs["sample"] - sample = sample.detach().cpu() - sample = torch.clamp(sample, -1., 1.) - sample = (sample + 1.) / 2. * 255 - sample = sample.numpy().astype(np.uint8) - sample = np.transpose(sample, (0, 2, 3, 1)) - a = Image.fromarray(sample[0]) - - # remove padding - a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) - - del model - gc.collect() - torch.cuda.empty_cache() - return a - - -def get_cond(selected_path): - example = dict() - up_f = 4 - c = selected_path.convert('RGB') - c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) - c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], - antialias=True) - c_up = rearrange(c_up, '1 c h w -> 1 h w c') - c = rearrange(c, '1 c h w -> 1 h w c') - c = 2. * c - 1. - - c = c.to(torch.device("cuda")) - example["LR_image"] = c - example["image"] = c_up - - return example - - -@torch.no_grad() -def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, - mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, - corrector_kwargs=None, x_t=None - ): - ddim = DDIMSampler(model) - bs = shape[0] - shape = shape[1:] - print(f"Sampling with eta = {eta}; steps: {steps}") - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, - normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, - mask=mask, x0=x0, temperature=temperature, verbose=False, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_t=x_t) - - return samples, intermediates - - -@torch.no_grad() -def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, - corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() - - z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=not (hasattr(model, 'split_input_params') - and model.cond_stage_key == 'coordinates_bbox'), - return_original_cond=True) - - if custom_shape is not None: - z = torch.randn(custom_shape) - print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") - - z0 = None - - log["input"] = x - log["reconstruction"] = xrec - - if ismap(xc): - log["original_conditioning"] = model.to_rgb(xc) - if hasattr(model, 'cond_stage_key'): - log[model.cond_stage_key] = model.to_rgb(xc) - - else: - log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_model: - log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_key == 'class_label': - log[model.cond_stage_key] = xc[model.cond_stage_key] - - with model.ema_scope("Plotting"): - t0 = time.time() - - sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, - eta=eta, - quantize_x0=quantize_x0, mask=None, x0=z0, - temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, - x_t=x_T) - t1 = time.time() - - if ddim_use_x0_pred: - sample = intermediates['pred_x0'][-1] - - x_sample = model.decode_first_stage(sample) - - try: - x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) - log["sample_noquant"] = x_sample_noquant - log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: - pass - - log["sample"] = x_sample - log["time"] = t1 - t0 - - return log diff --git a/modules/modelloader.py b/modules/modelloader.py index 7d2f0ade..e647f6fa 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): - sd = shared.script_path # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ - modules_dir = os.path.join(sd, "modules") + modules_dir = os.path.join(shared.script_path, "modules") for file in os.listdir(modules_dir): if "_model.py" in file: model_name = file.replace("_model.py", "") @@ -136,22 +135,13 @@ def load_upscalers(): importlib.import_module(full_model) except: pass + datas = [] - c_o = vars(shared.cmd_opts) + commandline_options = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): name = cls.__name__ - module_name = cls.__module__ - module = importlib.import_module(module_name) - class_ = getattr(module, name) cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" - opt_string = None - try: - if cmd_name in c_o: - opt_string = c_o[cmd_name] - except: - pass - scaler = class_(opt_string) - for child in scaler.scalers: - datas.append(child) + scaler = cls(commandline_options.get(cmd_name, None)) + datas += scaler.scalers shared.sd_upscalers = datas diff --git a/modules/scunet_model.py b/modules/scunet_model.py deleted file mode 100644 index 52360241..00000000 --- a/modules/scunet_model.py +++ /dev/null @@ -1,87 +0,0 @@ -import os.path -import sys -import traceback - -import PIL.Image -import numpy as np -import torch -from basicsr.utils.download_util import load_file_from_url - -import modules.upscaler -from modules import devices, modelloader -from modules.scunet_model_arch import SCUNet as net - - -class UpscalerScuNET(modules.upscaler.Upscaler): - def __init__(self, dirname): - self.name = "ScuNET" - self.model_name = "ScuNET GAN" - self.model_name2 = "ScuNET PSNR" - self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" - self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" - self.user_path = dirname - super().__init__() - model_paths = self.find_models(ext_filter=[".pth"]) - scalers = [] - add_model2 = True - for file in model_paths: - if "http" in file: - name = self.model_name - else: - name = modelloader.friendly_name(file) - if name == self.model_name2 or file == self.model_url2: - add_model2 = False - try: - scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) - scalers.append(scaler_data) - except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - if add_model2: - scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) - scalers.append(scaler_data2) - self.scalers = scalers - - def do_upscale(self, img: PIL.Image, selected_file): - torch.cuda.empty_cache() - - model = self.load_model(selected_file) - if model is None: - return img - - device = devices.device_scunet - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) - - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') - - def load_model(self, path: str): - device = devices.device_scunet - if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, - progress=True) - else: - filename = path - if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: - print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) - return None - - model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for k, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model - diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py deleted file mode 100644 index 43ca8d36..00000000 --- a/modules/scunet_model_arch.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. - Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), - dims=(1, 2)) - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 8202d8e5..dc45fcaa 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) -parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) -parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) -parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") @@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -95,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) script_loading.preload_extensions(extensions.extensions_dir, parser) +script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) cmd_opts = parser.parse_args() @@ -112,8 +110,8 @@ restricted_opts = { cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ + (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) device = devices.device weight_load_location = None if cmd_opts.lowram else "cpu" @@ -326,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), - "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), - "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) diff --git a/modules/swinir_model.py b/modules/swinir_model.py deleted file mode 100644 index 483eabd4..00000000 --- a/modules/swinir_model.py +++ /dev/null @@ -1,157 +0,0 @@ -import contextlib -import os - -import numpy as np -import torch -from PIL import Image -from basicsr.utils.download_util import load_file_from_url -from tqdm import tqdm - -from modules import modelloader, devices -from modules.shared import cmd_opts, opts -from modules.swinir_model_arch import SwinIR as net -from modules.swinir_model_arch_v2 import Swin2SR as net2 -from modules.upscaler import Upscaler, UpscalerData - - -class UpscalerSwinIR(Upscaler): - def __init__(self, dirname): - self.name = "SwinIR" - self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ - "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ - "-L_x4_GAN.pth " - self.model_name = "SwinIR 4x" - self.user_path = dirname - super().__init__() - scalers = [] - model_files = self.find_models(ext_filter=[".pt", ".pth"]) - for model in model_files: - if "http" in model: - name = self.model_name - else: - name = modelloader.friendly_name(model) - model_data = UpscalerData(name, model, self) - scalers.append(model_data) - self.scalers = scalers - - def do_upscale(self, img, model_file): - model = self.load_model(model_file) - if model is None: - return img - model = model.to(devices.device_swinir) - img = upscale(img, model) - try: - torch.cuda.empty_cache() - except: - pass - return img - - def load_model(self, path, scale=4): - if "http" in path: - dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) - else: - filename = path - if filename is None or not os.path.exists(filename): - return None - if filename.endswith(".v2.pth"): - model = net2( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = net( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) - if not cmd_opts.no_half: - model = model.half() - return model - - -def upscale( - img, - model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - window_size=8, - scale=4, -): - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_swinir) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - -def inference(img, model, tile, tile_overlap, window_size, scale): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - for w_idx in w_idx_list: - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py deleted file mode 100644 index 863f42db..00000000 --- a/modules/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. -# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py deleted file mode 100644 index 0e28ae6e..00000000 --- a/modules/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. -# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index 2eb0b684..3acb9b48 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,7 +28,6 @@ import modules.codeformer_model import modules.generation_parameters_copypaste as parameters_copypaste import modules.gfpgan_model import modules.hypernetworks.ui -import modules.ldsr_model import modules.scripts import modules.shared as shared import modules.styles diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 42667941..b487ac25 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -78,6 +78,12 @@ def extension_table(): """ for ext in extensions.extensions: + remote = "" + if ext.is_builtin: + remote = "built-in" + elif ext.remote: + remote = f"""{html.escape("built-in" if ext.is_builtin else ext.remote or '')}""" + if ext.can_update: ext_status = f"""""" else: @@ -86,7 +92,7 @@ def extension_table(): code += f""" - {html.escape(ext.remote or '')} + {remote} {ext_status} """ diff --git a/webui.py b/webui.py index 16e7ec1a..78204d11 100644 --- a/webui.py +++ b/webui.py @@ -53,10 +53,11 @@ def initialize(): codeformer.setup_model(cmd_opts.codeformer_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) - modelloader.load_upscalers() modules.scripts.load_scripts() + modelloader.load_upscalers() + modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) @@ -177,6 +178,8 @@ def webui(): print('Reloading custom scripts') modules.scripts.reload_scripts() + modelloader.load_upscalers() + print('Reloading modules: modules.ui') importlib.reload(modules.ui) print('Refreshing Model List') -- cgit v1.2.3 From 12ade469c8fb65468da072c6489283acfb420763 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Sun, 4 Dec 2022 12:33:15 -0300 Subject: add queuing by default to avoid timeout on client side when share=True --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 78204d11..98f3c645 100644 --- a/webui.py +++ b/webui.py @@ -137,7 +137,7 @@ def webui(): shared.demo = modules.ui.create_ui() - app, local_url, share_url = shared.demo.launch( + app, local_url, share_url = shared.demo.queue().launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, -- cgit v1.2.3 From fcf372e5d044164931f0b779d0c63e9dcfbafef0 Mon Sep 17 00:00:00 2001 From: Pi Esposito Date: Sun, 4 Dec 2022 14:13:31 -0300 Subject: set default to avoid breaking stuff --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 98f3c645..c2d0c6be 100644 --- a/webui.py +++ b/webui.py @@ -137,7 +137,7 @@ def webui(): shared.demo = modules.ui.create_ui() - app, local_url, share_url = shared.demo.queue().launch( + app, local_url, share_url = shared.demo.queue(default_enabled=False).launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, -- cgit v1.2.3 From c0355caefe3d82e304e6d832699d581fc8f9fbf9 Mon Sep 17 00:00:00 2001 From: Jim Hays Date: Wed, 14 Dec 2022 21:01:32 -0500 Subject: Fix various typos --- README.md | 4 ++-- javascript/contextMenus.js | 24 ++++++++++++------------ javascript/progressbar.js | 12 ++++++------ javascript/ui.js | 2 +- modules/api/api.py | 18 +++++++++--------- modules/api/models.py | 2 +- modules/images.py | 4 ++-- modules/processing.py | 14 +++++++------- modules/safe.py | 4 ++-- modules/scripts.py | 4 ++-- modules/sd_hijack_inpainting.py | 6 +++--- modules/sd_hijack_unet.py | 2 +- modules/textual_inversion/dataset.py | 10 +++++----- modules/textual_inversion/textual_inversion.py | 16 ++++++++-------- scripts/prompt_matrix.py | 10 +++++----- webui.py | 4 ++-- 16 files changed, 68 insertions(+), 68 deletions(-) (limited to 'webui.py') diff --git a/README.md b/README.md index 55990581..556000fb 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Use VAEs - Estimated completion time in progress bar - API -- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. -- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) +- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. +- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions ## Installation and Running diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index fe67c42e..11bcce1b 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -9,7 +9,7 @@ contextMenuInit = function(){ function showContextMenu(event,element,menuEntries){ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; - let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; + let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; let oldMenu = gradioApp().querySelector('#context-menu') if(oldMenu){ @@ -61,15 +61,15 @@ contextMenuInit = function(){ } - function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ - - currentItems = menuSpecs.get(targetEmementSelector) - + function appendContextMenuOption(targetElementSelector,entryName,entryFunction){ + + currentItems = menuSpecs.get(targetElementSelector) + if(!currentItems){ currentItems = [] - menuSpecs.set(targetEmementSelector,currentItems); + menuSpecs.set(targetElementSelector,currentItems); } - let newItem = {'id':targetEmementSelector+'_'+uid(), + let newItem = {'id':targetElementSelector+'_'+uid(), 'name':entryName, 'func':entryFunction, 'isNew':true} @@ -97,7 +97,7 @@ contextMenuInit = function(){ if(source.id && source.id.indexOf('check_progress')>-1){ return } - + let oldMenu = gradioApp().querySelector('#context-menu') if(oldMenu){ oldMenu.remove() @@ -117,7 +117,7 @@ contextMenuInit = function(){ }) }); eventListenerApplied=true - + } return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] @@ -152,8 +152,8 @@ addContextMenuEventListener = initResponse[2]; generateOnRepeat('#img2img_generate','#img2img_interrupt'); }) - let cancelGenerateForever = function(){ - clearInterval(window.generateOnRepeatInterval) + let cancelGenerateForever = function(){ + clearInterval(window.generateOnRepeatInterval) } appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) @@ -162,7 +162,7 @@ addContextMenuEventListener = initResponse[2]; appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) appendContextMenuOption('#roll','Roll three', - function(){ + function(){ let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); setTimeout(function(){rollbutton.click()},100) setTimeout(function(){rollbutton.click()},200) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index d58737c4..d6323ed9 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -3,7 +3,7 @@ global_progressbars = {} galleries = {} galleryObservers = {} -// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running +// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running timeoutIds = {} function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ @@ -20,21 +20,21 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) - + if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ if(progressbar.innerText){ let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; if(document.title != newtitle){ - document.title = newtitle; + document.title = newtitle; } }else{ let newtitle = 'Stable Diffusion' if(document.title != newtitle){ - document.title = newtitle; + document.title = newtitle; } } } - + if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ global_progressbars[id_progressbar] = progressbar @@ -63,7 +63,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip skip.style.display = "none" } interrupt.style.display = "none" - + //disconnect observer once generation finished, so user can close selected image if they want if (galleryObservers[id_gallery]) { galleryObservers[id_gallery].disconnect(); diff --git a/javascript/ui.js b/javascript/ui.js index 2cb280e5..587dd782 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -100,7 +100,7 @@ function create_submit_args(args){ // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image. // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate. - // I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some. + // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some. // If gradio at some point stops sending outputs, this may break something if(Array.isArray(res[res.length - 3])){ res[res.length - 3] = null diff --git a/modules/api/api.py b/modules/api/api.py index 89935a70..33845045 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -67,10 +67,10 @@ def encode_pil_to_base64(image): class Api: def __init__(self, app: FastAPI, queue_lock: Lock): if shared.cmd_opts.api_auth: - self.credenticals = dict() + self.credentials = dict() for auth in shared.cmd_opts.api_auth.split(","): user, password = auth.split(":") - self.credenticals[user] = password + self.credentials[user] = password self.router = APIRouter() self.app = app @@ -93,7 +93,7 @@ class Api: self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) - self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) @@ -102,9 +102,9 @@ class Api: return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) return self.app.add_api_route(path, endpoint, **kwargs) - def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): - if credenticals.username in self.credenticals: - if compare_digest(credenticals.password, self.credenticals[credenticals.username]): + def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())): + if credentials.username in self.credentials: + if compare_digest(credentials.password, self.credentials[credentials.username]): return True raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) @@ -239,7 +239,7 @@ class Api: def interrogateapi(self, interrogatereq: InterrogateRequest): image_b64 = interrogatereq.image if image_b64 is None: - raise HTTPException(status_code=404, detail="Image not found") + raise HTTPException(status_code=404, detail="Image not found") img = decode_base64_to_image(image_b64) img = img.convert('RGB') @@ -252,7 +252,7 @@ class Api: processed = deepbooru.model.tag(img) else: raise HTTPException(status_code=404, detail="Model not found") - + return InterrogateResponse(caption=processed) def interruptapi(self): @@ -308,7 +308,7 @@ class Api: def get_realesrgan_models(self): return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)] - def get_promp_styles(self): + def get_prompt_styles(self): styleList = [] for k in shared.prompt_styles.styles: style = shared.prompt_styles.styles[k] diff --git a/modules/api/models.py b/modules/api/models.py index f77951fc..a22bc6b3 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel): upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") - upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") + upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") diff --git a/modules/images.py b/modules/images.py index 8146f580..93a14289 100644 --- a/modules/images.py +++ b/modules/images.py @@ -429,7 +429,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory. basename (`str`): The base filename which will be applied to `filename pattern`. - seed, prompt, short_filename, + seed, prompt, short_filename, extension (`str`): Image file extension, default is `png`. pngsectionname (`str`): @@ -590,7 +590,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr) + print(f"Error parsing NovelAI image generation parameters:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return geninfo, items diff --git a/modules/processing.py b/modules/processing.py index 24c537d1..fe7f4faf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -147,11 +147,11 @@ class StableDiffusionProcessing(): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + image_conditioning = image_conditioning.to(x.dtype) return image_conditioning @@ -199,7 +199,7 @@ class StableDiffusionProcessing(): source_image * (1.0 - conditioning_mask), getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) ) - + # Encode the new masked image using first stage of network. conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) @@ -537,7 +537,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -612,7 +612,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -704,7 +704,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] - """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images""" + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -720,7 +720,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - # Avoid making the inpainting conditioning unless necessary as + # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) diff --git a/modules/safe.py b/modules/safe.py index 10460ad0..20e9d2fa 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -80,7 +80,7 @@ def check_pt(filename, extra_handler): # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: check_zip_filenames(filename, z.namelist()) - + # find filename of data.pkl in zip file: '/data.pkl' data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] if len(data_pkl_filenames) == 0: @@ -108,7 +108,7 @@ def load(filename, *args, **kwargs): def load_with_extra(filename, extra_handler=None, *args, **kwargs): """ - this functon is intended to be used by extensions that want to load models with + this function is intended to be used by extensions that want to load models with some extra classes in them that the usual unpickler would find suspicious. Use the extra_handler argument to specify a function that takes module and field name as text, diff --git a/modules/scripts.py b/modules/scripts.py index 23ca195d..722f8685 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -36,7 +36,7 @@ class Script: def ui(self, is_img2img): """this function should create gradio UI elements. See https://gradio.app/docs/#components The return value should be an array of all components that are used in processing. - Values of those returned componenbts will be passed to run() and process() functions. + Values of those returned components will be passed to run() and process() functions. """ pass @@ -47,7 +47,7 @@ class Script: This function should return: - False if the script should not be shown in UI at all - - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - True if the script should be shown in UI if it's selected in the scripts dropdown - script.AlwaysVisible if the script should be shown in UI at all times """ diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 938f9a58..d72f83fd 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -209,7 +209,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) - + if isinstance(c, dict): assert isinstance(unconditional_conditioning, dict) c_in = dict() @@ -278,7 +278,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) return x_prev, pred_x0, e_t - + # ================================================================================================= # Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. # Adapted from: @@ -326,7 +326,7 @@ def do_inpainting_hijack(): # most of this stuff seems to no longer be needed because it is already included into SD2.0 # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint # p_sample_plms is needed because PLMS can't work with dicts as conditionings - # this file should be cleaned up later if weverything tuens out to work fine + # this file should be cleaned up later if everything turns out to work fine # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 1b9d7757..18daf8c1 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -4,7 +4,7 @@ import torch class TorchHijackForUnet: """ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; - this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64 + this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 """ def __getattr__(self, item): diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 2dc64c3c..88d68c76 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -28,9 +28,9 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None - + self.placeholder_token = placeholder_token self.width = width @@ -50,14 +50,14 @@ class PersonalizedBase(Dataset): self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - + self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): if shared.state.interrupted: - raise Exception("inturrupted") + raise Exception("interrupted") try: image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: @@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader): self.collate_fn = collate_wrapper_random else: self.collate_fn = collate_wrapper - + class BatchLoader: def __init__(self, data): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e28c357a..daf3997b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -133,7 +133,7 @@ class EmbeddingDatabase: process_file(fullfn, fn) except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) + print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue @@ -194,7 +194,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): csv_writer.writeheader() epoch = (step - 1) // epoch_len - epoch_step = (step - 1) % epoch_len + epoch_step = (step - 1) % epoch_len csv_writer.writerow({ "step": step, @@ -270,9 +270,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed - + pin_memory = shared.opts.pin_memory - + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) latent_sampling_method = ds.latent_sampling_method @@ -295,12 +295,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ loss_step = 0 _loss_step = 0 #internal - + last_saved_file = "" last_saved_image = "" forced_filename = "" embedding_yet_to_be_embedded = False - + pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps-initial_step) * gradient_step): @@ -327,10 +327,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ c = shared.sd_model.cond_stage_model(batch.cond_text) loss = shared.sd_model(x, c)[0] / gradient_step del x - + _loss_step += loss.item() scaler.scale(loss).backward() - + # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index c53ca28c..4c79eaef 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] - first_pocessed = None + first_processed = None state.job_count = len(xs) * len(ys) @@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" processed = cell(x, y) - if first_pocessed is None: - first_pocessed = processed + if first_processed is None: + first_processed = processed res.append(processed.images[0]) grid = images.image_grid(res, rows=len(ys)) grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - first_pocessed.images = [grid] + first_processed.images = [grid] - return first_pocessed + return first_processed class Script(scripts.Script): diff --git a/webui.py b/webui.py index c2d0c6be..4b32e77d 100644 --- a/webui.py +++ b/webui.py @@ -153,8 +153,8 @@ def webui(): # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for # an attacker to trick the user into opening a malicious HTML page, which makes a request to the - # running web ui and do whatever the attcker wants, including installing an extension and - # runnnig its code. We disable this here. Suggested by RyotaK. + # running web ui and do whatever the attacker wants, including installing an extension and + # running its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] setup_cors(app) -- cgit v1.2.3 From 9fd457e21d6c809a69a1318f03d75f7b3e09b865 Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Thu, 15 Dec 2022 21:57:48 +0300 Subject: allow_credentials and allow_headers for api from https://fastapi.tiangolo.com/tutorial/cors/ --- webui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index c2d0c6be..13a4d14a 100644 --- a/webui.py +++ b/webui.py @@ -90,11 +90,11 @@ def initialize(): def setup_cors(app): if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) def create_api(app): -- cgit v1.2.3 From 35e1017e3ea0a3ad9ec28c9b447200a70a65c0ae Mon Sep 17 00:00:00 2001 From: Akiba Date: Fri, 16 Dec 2022 20:43:09 +0800 Subject: fix: xformers --- modules/import_hook.py | 18 ++++++++++++++++++ webui.py | 1 + 2 files changed, 19 insertions(+) create mode 100644 modules/import_hook.py (limited to 'webui.py') diff --git a/modules/import_hook.py b/modules/import_hook.py new file mode 100644 index 00000000..eb10e4fd --- /dev/null +++ b/modules/import_hook.py @@ -0,0 +1,18 @@ +import builtins +import sys + +old_import = builtins.__import__ +IMPORT_BLACKLIST = [] + + +if "xformers" not in "".join(sys.argv): + IMPORT_BLACKLIST.append("xformers") + + +def import_hook(*args, **kwargs): + if args[0] in IMPORT_BLACKLIST: + raise ImportError("Import of %s is blacklisted" % args[0]) + return old_import(*args, **kwargs) + + +builtins.__import__ = import_hook diff --git a/webui.py b/webui.py index c2d0c6be..18ee5a3d 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from modules import import_hook from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path -- cgit v1.2.3 From 2d5a5076bb2a0c05cc27d75a1bcadab7f32a46d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:38:21 +0300 Subject: Make it so that upscalers are not repeated when restarting UI. --- modules/modelloader.py | 20 ++++++++++++++++++++ webui.py | 14 +++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) (limited to 'webui.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index e647f6fa..6a1a7ac8 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): pass +builtin_upscaler_classes = [] +forbidden_upscaler_classes = set() + + +def list_builtin_upscalers(): + load_upscalers() + + builtin_upscaler_classes.clear() + builtin_upscaler_classes.extend(Upscaler.__subclasses__()) + + +def forbid_loaded_nonbuiltin_upscalers(): + for cls in Upscaler.__subclasses__(): + if cls not in builtin_upscaler_classes: + forbidden_upscaler_classes.add(cls) + + def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ @@ -139,6 +156,9 @@ def load_upscalers(): datas = [] commandline_options = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): + if cls in forbidden_upscaler_classes: + continue + name = cls.__name__ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" scaler = cls(commandline_options.get(cmd_name, None)) diff --git a/webui.py b/webui.py index 3aee8792..c7d55a97 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,5 @@ import os +import sys import threading import time import importlib @@ -55,8 +56,8 @@ def initialize(): gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) + modelloader.list_builtin_upscalers() modules.scripts.load_scripts() - modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() @@ -169,23 +170,22 @@ def webui(): modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) + print('Restarting UI...') sd_samplers.set_samplers() - print('Reloading extensions') extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) - print('Reloading custom scripts') + modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() modelloader.load_upscalers() - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + modules.sd_models.list_models() - print('Restarting Gradio') if __name__ == "__main__": -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 26 ++++++++++++++++++-------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) (limited to 'webui.py') diff --git a/modules/errors.py b/modules/errors.py index 372dc51a..a668c014 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. + """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(task, e) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a93..7588c47b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a97..13375e71 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 8111b5569d07c7ac3b695e28171aede728b4ae56 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 3 Jan 2023 20:43:05 -0500 Subject: Add support for PyTorch nightly and local builds --- modules/devices.py | 28 +++++++++++++++++++++++----- webui.py | 7 ++++++- 2 files changed, 29 insertions(+), 6 deletions(-) (limited to 'webui.py') diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/webui.py b/webui.py index 13375e71..ddfaea95 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import threading import time import importlib import signal -import threading +import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -13,6 +13,11 @@ from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path +import torch +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is - passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 5e6566324bba20554bcc04f3dda798e560397f38 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 07:06:26 -0500 Subject: Always end version number with a digit --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 733a06b5..8737e593 100644 --- a/webui.py +++ b/webui.py @@ -16,7 +16,7 @@ from modules.paths import script_path import torch # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer -- cgit v1.2.3 From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 23:35:40 +0300 Subject: make a dropdown for prompt template selection --- modules/hypernetworks/hypernetwork.py | 7 ++++-- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------ modules/ui.py | 11 ++++++-- webui.py | 3 +++ 5 files changed, 45 insertions(+), 12 deletions(-) (limited to 'webui.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = textual_inversion.textual_inversion_templates.get(template_filename, None) + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys import traceback import inspect +from collections import namedtuple import torch import tqdm @@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler -from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, - insert_image_data_embed, extract_image_data_embed, - caption_image_overlay) +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" @@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert template_file, "Prompt template file is empty" - assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" @@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui +from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text @@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row], ) + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with FormRow(): @@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") diff --git a/webui.py b/webui.py index 8737e593..47d372c7 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae import modules.txt2img import modules.script_callbacks +import modules.textual_inversion.textual_inversion import modules.ui from modules import modelloader @@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list() + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + try: modules.sd_models.load_model() except Exception as e: -- cgit v1.2.3