aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--README.md17
-rw-r--r--javascript/hints.js2
-rw-r--r--javascript/progressbar.js1
-rw-r--r--javascript/textualInversion.js8
-rw-r--r--javascript/ui.js27
-rw-r--r--launch.py4
-rw-r--r--modules/devices.py3
-rw-r--r--modules/esrgan_model.py4
-rw-r--r--modules/gfpgan_model.py6
-rw-r--r--modules/images.py9
-rw-r--r--modules/img2img.py11
-rw-r--r--modules/interrogate.py1
-rw-r--r--modules/modelloader.py21
-rw-r--r--modules/paths.py1
-rw-r--r--modules/processing.py17
-rw-r--r--modules/scripts.py34
-rw-r--r--modules/scunet_model.py90
-rw-r--r--modules/scunet_model_arch.py265
-rw-r--r--modules/sd_hijack.py324
-rw-r--r--modules/sd_hijack_optimizations.py164
-rw-r--r--modules/sd_models.py28
-rw-r--r--modules/sd_samplers.py20
-rw-r--r--modules/shared.py18
-rw-r--r--modules/swinir_model.py27
-rw-r--r--modules/textual_inversion/dataset.py78
-rw-r--r--modules/textual_inversion/preprocess.py75
-rw-r--r--modules/textual_inversion/textual_inversion.py271
-rw-r--r--modules/textual_inversion/ui.py40
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py293
-rw-r--r--requirements.txt2
-rw-r--r--requirements_versions.txt1
-rw-r--r--scripts/sd_upscale.py6
-rw-r--r--style.css10
-rw-r--r--textual_inversion_templates/style.txt19
-rw-r--r--textual_inversion_templates/style_filewords.txt19
-rw-r--r--textual_inversion_templates/subject.txt27
-rw-r--r--textual_inversion_templates/subject_filewords.txt27
-rw-r--r--webui.py65
40 files changed, 1612 insertions, 428 deletions
diff --git a/.gitignore b/.gitignore
index 3532dab3..7afc9395 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
+/textual_inversion
diff --git a/README.md b/README.md
index 5ded94f9..ec3d7532 100644
--- a/README.md
+++ b/README.md
@@ -11,12 +11,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
-- Prompt
-- Stable Diffusion upscale
+- Prompt Matrix
+- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- - a man in a ((txuedo)) - will pay more attentinoto tuxedo
- - a man in a (txuedo:1.21) - alternative syntax
-- Loopback, run img2img procvessing multiple times
+ - a man in a ((tuxedo)) - will pay more attention to tuxedo
+ - a man in a (tuxedo:1.21) - alternative syntax
+- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
- have as many embeddings as you want and use any names you like for them
@@ -35,15 +35,15 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
- - get length of prompt in tokensas you type
- - get a warning after geenration if some text was truncated
+ - get length of prompt in tokens as you type
+ - get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page
-- Running arbitrary python code from UI (must run with commandline flag to enable)
+- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
@@ -113,6 +113,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
diff --git a/javascript/hints.js b/javascript/hints.js
index 84694eeb..e72e9338 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -15,7 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
- "\uD83D\uDCC2": "Open images output directory",
+ "\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 21f25b38..1e297abb 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
+ check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
new file mode 100644
index 00000000..8061be08
--- /dev/null
+++ b/javascript/textualInversion.js
@@ -0,0 +1,8 @@
+
+
+function start_training_textual_inversion(){
+ requestProgress('ti')
+ gradioApp().querySelector('#ti_error').innerHTML=''
+
+ return args_to_array(arguments)
+}
diff --git a/javascript/ui.js b/javascript/ui.js
index bfe02410..b1053201 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -199,6 +199,26 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
+function update_txt2img_tokens(...args) {
+ update_token_counter("txt2img_token_button")
+ if (args.length == 2)
+ return args[0]
+ return args;
+}
+
+function update_img2img_tokens(...args) {
+ update_token_counter("img2img_token_button")
+ if (args.length == 2)
+ return args[0]
+ return args;
+}
+
+function update_token_counter(button_id) {
+ if (token_timeout)
+ clearTimeout(token_timeout);
+ token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+}
+
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
@@ -207,8 +227,7 @@ function submit_prompt(event, generate_button_id) {
}
}
-function update_token_counter(button_id) {
- if (token_timeout)
- clearTimeout(token_timeout);
- token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+function restart_reload(){
+ document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
+ setTimeout(function(){location.reload()},2000)
}
diff --git a/launch.py b/launch.py
index d2793ed2..57405fea 100644
--- a/launch.py
+++ b/launch.py
@@ -15,6 +15,7 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
+clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
@@ -111,6 +112,9 @@ if not skip_torch_cuda_test:
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
+if not is_installed("clip"):
+ run_pip(f"install {clip_package}", "clip")
+
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
diff --git a/modules/devices.py b/modules/devices.py
index 08bb26d6..5d9c7a07 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-
device = get_optimal_device()
device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
-
+dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index ea91abfe..4aed9283 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
- self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
- self.model_name = "ESRGAN 4x"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+ self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index fcd8544a..f1a564b7 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -97,11 +97,7 @@ def setup_model(dirname):
return "GFPGAN"
def restore(self, np_image):
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- return np_image
+ return gfpgan_fix_faces(np_image)
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
diff --git a/modules/images.py b/modules/images.py
index f1aed5d6..bba55158 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -311,7 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
+
+ #currently disabled if using the save button, will work otherwise
+ # if enabled it will cause a bug because styles is not included in the save_files data dictionary
+ if hasattr(p, "styles"):
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False))
+
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
@@ -374,7 +379,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
diff --git a/modules/img2img.py b/modules/img2img.py
index 03e934e9..2ff8e261 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+ save_normally = output_dir == ''
+
p.do_not_save_grid = True
- p.do_not_save_samples = True
+ p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
@@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
- processed_image.save(os.path.join(output_dir, filename))
+ if not save_normally:
+ processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@@ -103,7 +106,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
)
- print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
+
+ if shared.cmd_opts.enable_console_prompts:
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur
diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..eed87144 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 8c862b42..b0f2f33d 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -5,7 +5,6 @@ import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
-
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
@@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
- full_path = os.path.join(place, file)
+ full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
@@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
+ sd = shared.script_path
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
+ # so we'll try to import any _model.py files before looking in __subclasses__
+ modules_dir = os.path.join(sd, "modules")
+ for file in os.listdir(modules_dir):
+ if "_model.py" in file:
+ model_name = file.replace("_model.py", "")
+ full_model = f"modules.{model_name}_model"
+ try:
+ importlib.import_module(full_model)
+ except:
+ pass
datas = []
+ c_o = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
- cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None
try:
- opt_string = shared.opts.__getattr__(cmd_name)
+ if cmd_name in c_o:
+ opt_string = c_o[cmd_name]
except:
pass
scaler = class_(opt_string)
diff --git a/modules/paths.py b/modules/paths.py
index ceb80417..606f7d66 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,7 +20,6 @@ path_dirs = [
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/processing.py b/modules/processing.py
index 7eeb5191..0a4b6198 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -79,7 +79,7 @@ class StableDiffusionProcessing:
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
+ self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
@@ -130,7 +130,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
-
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
diff --git a/modules/scripts.py b/modules/scripts.py
index 7c3bd5e7..45230f9a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -162,6 +162,40 @@ class ScriptRunner:
return processed
+ def reload_sources(self):
+ for si, script in list(enumerate(self.scripts)):
+ with open(script.filename, "r", encoding="utf8") as file:
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
+ text = file.read()
+
+ from types import ModuleType
+
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
+def reload_script_body_only():
+ scripts_txt2img.reload_sources()
+ scripts_img2img.reload_sources()
+
+
+def reload_scripts(basedir):
+ global scripts_txt2img, scripts_img2img
+
+ scripts_data.clear()
+ load_scripts(basedir)
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
new file mode 100644
index 00000000..7987ac14
--- /dev/null
+++ b/modules/scunet_model.py
@@ -0,0 +1,90 @@
+import os.path
+import sys
+import traceback
+
+import PIL.Image
+import numpy as np
+import torch
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.upscaler
+from modules import shared, modelloader
+from modules.paths import models_path
+from modules.scunet_model_arch import SCUNet as net
+
+
+class UpscalerScuNET(modules.upscaler.Upscaler):
+ def __init__(self, dirname):
+ self.name = "ScuNET"
+ self.model_path = os.path.join(models_path, self.name)
+ self.model_name = "ScuNET GAN"
+ self.model_name2 = "ScuNET PSNR"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
+ self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
+ self.user_path = dirname
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pth"])
+ scalers = []
+ add_model2 = True
+ for file in model_paths:
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+ if name == self.model_name2 or file == self.model_url2:
+ add_model2 = False
+ try:
+ scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
+ scalers.append(scaler_data)
+ except Exception:
+ print(f"Error loading ScuNET model: {file}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ if add_model2:
+ scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
+ scalers.append(scaler_data2)
+ self.scalers = scalers
+
+ def do_upscale(self, img: PIL.Image, selected_file):
+ torch.cuda.empty_cache()
+
+ model = self.load_model(selected_file)
+ if model is None:
+ return img
+
+ device = shared.device
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.moveaxis(img, 2, 0) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(shared.device)
+
+ img = img.to(device)
+ with torch.no_grad():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ torch.cuda.empty_cache()
+ return PIL.Image.fromarray(output, 'RGB')
+
+ def load_model(self, path: str):
+ device = shared.device
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
+ progress=True)
+ else:
+ filename = path
+ if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
+ print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
+ return None
+
+ model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model.load_state_dict(torch.load(filename), strict=True)
+ model.eval()
+ for k, v in model.named_parameters():
+ v.requires_grad = False
+ model = model.to(device)
+
+ return model
+
diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py
new file mode 100644
index 00000000..972a2639
--- /dev/null
+++ b/modules/scunet_model_arch.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from timm.models.layers import trunc_normal_, DropPath
+
+
+class WMSA(nn.Module):
+ """ Self-attention module in Swin Transformer
+ """
+
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
+ super(WMSA, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.head_dim = head_dim
+ self.scale = self.head_dim ** -0.5
+ self.n_heads = input_dim // head_dim
+ self.window_size = window_size
+ self.type = type
+ self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
+
+ self.relative_position_params = nn.Parameter(
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
+
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
+
+ trunc_normal_(self.relative_position_params, std=.02)
+ self.relative_position_params = torch.nn.Parameter(
+ self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
+ 2).transpose(
+ 0, 1))
+
+ def generate_mask(self, h, w, p, shift):
+ """ generating the mask of SW-MSA
+ Args:
+ shift: shift parameters in CyclicShift.
+ Returns:
+ attn_mask: should be (1 1 w p p),
+ """
+ # supporting sqaure.
+ attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
+ if self.type == 'W':
+ return attn_mask
+
+ s = p - shift
+ attn_mask[-1, :, :s, :, s:, :] = True
+ attn_mask[-1, :, s:, :, :s, :] = True
+ attn_mask[:, -1, :, :s, :, s:] = True
+ attn_mask[:, -1, :, s:, :, :s] = True
+ attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
+ return attn_mask
+
+ def forward(self, x):
+ """ Forward pass of Window Multi-head Self-attention module.
+ Args:
+ x: input tensor with shape of [b h w c];