diff options
42 files changed, 2487 insertions, 1451 deletions
@@ -1,10 +1,13 @@ __pycache__ -/ESRGAN +*.ckpt +*.pth +/ESRGAN/* +/SwinIR/* /repositories /venv /tmp /model.ckpt -/models/**/*.ckpt +/models/**/* /GFPGANv1.3.pth /gfpgan/weights/*.pth /ui-config.json @@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion. 
+Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
+
## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
- Original txt2img and img2img modes
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
-- Prompt matrix
+- Prompt
- Stable Diffusion upscale
-- Attention
-- Loopback
-- X/Y plot
+- Attention, specify parts of text that the model should pay more attention to
+ - a man in a ((txuedo)) - will pay more attentinoto tuxedo
+ - a man in a (txuedo:1.21) - alternative syntax
+- Loopback, run img2img procvessing multiple times
+- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
+ - have as many embeddings as you want and use any names you like for them
+ - use multiple embeddings with different numbers of vectors per token
+ - works with half precision floating point numbers
- Extras tab with:
- GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler
- - ESRGAN, neural network with a lot of third party models
+ - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler
- LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
- Interrupt processing at any time
-- 4GB video card support
-- Correct seeds for batches
+- 4GB video card support (also reports of 2GB working)
+- Correct seeds for batches
- Prompt length validation
-- Generation parameters added as text to PNG
-- Tab to view an existing picture's generation parameters
+ - get length of prompt in tokensas you type
+ - get a warning after geenration if some text was truncated
+- Generation parameters
+ - parameters you used to generate images are saved with that image
+ - in PNG chunks for PNG, in EXIF for JPEG
+ - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
+ - can be disabled in settings
- Settings page
-- Running custom code from UI
+- Running arbitrary python code from UI (must run with commandline flag to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
-- Tiling support: UI checkbox to create images that can be tiled like textures
+- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
-- Negative prompt
-- Styles
-- Variations
-- Seed resizing
-- CLIP interrogator
-- Prompt Editing
-- Batch Processing
+- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
+- Styles, a way to save part of prompt and easily apply them via dropdown later
+- Variations, a way to generate same image but with tiny differences
+- Seed resizing, a way to generate same image but at slightly different resolution
+- CLIP interrogator, a button that tries to guess prompt from an image
+- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
+- Batch Processing, process a group of files using img2img
- Img2img Alternative
-- Highres Fix
-- LDSR Upscaling
+- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
+- Reloading checkpoints on the fly
+- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
+- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
+## Contributing
+Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
+
## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
diff --git a/SwinIR/put_swinir_models_here.txt b/SwinIR/put_swinir_models_here.txt deleted file mode 100644 index 8b137891..00000000 --- a/SwinIR/put_swinir_models_here.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/artists.csv b/artists.csv index c92d08f5..14ba2022 100644 --- a/artists.csv +++ b/artists.csv @@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white Leonora Carrington,0.73726475,scribbles
Hieronymus Bosch,0.7369955,scribbles
A. J. Casson,0.73666203,scribbles
-A.J.Casson,0.73666203,scribbles
Chaim Soutine,0.73662066,scribbles
Artur Bordalo,0.7364549,weird
Thomas Allom,0.68792284,fineart
@@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact Bastien L. Deharme,0.583349,special
František Jakub Prokyš,0.58782333,fineart
Jesper Ejsing,0.58782053,fineart
-Jesper Ejsing,0.58782053,fineart
Odd Nerdrum,0.53551745,digipa-high-impact
Tom Lovell,0.5877577,fineart
Ayami Kojima,0.5877416,fineart
diff --git a/ESRGAN/Put ESRGAN models here.txt b/embeddings/Place Textual Inversion embeddings here.txt index e69de29b..e69de29b 100644 --- a/ESRGAN/Put ESRGAN models here.txt +++ b/embeddings/Place Textual Inversion embeddings here.txt diff --git a/javascript/hints.js b/javascript/hints.js index 59dd770c..84694eeb 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -15,6 +15,7 @@ titles = { "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.", + "\uD83D\uDCC2": "Open images output directory", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", @@ -57,8 +58,8 @@ titles = { "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", - "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", - "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", + "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", + "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Loopback": "Process an image, use it as an input, repeat.", diff --git a/javascript/ui.js b/javascript/ui.js index 562d2552..bfe02410 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -186,10 +186,12 @@ onUiUpdate(function(){ if (!txt2img_textarea) { txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); + txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate")); } if (!img2img_textarea) { img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); + img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate")); } }) @@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; +function submit_prompt(event, generate_button_id) { + if (event.altKey && event.keyCode === 13) { + event.preventDefault(); + gradioApp().getElementById(generate_button_id).click(); + return; + } +} + function update_token_counter(button_id) { if (token_timeout) clearTimeout(token_timeout); @@ -1,5 +1,4 @@ # this scripts installs necessary requirements and launches main program in webui.py
-
import subprocess
import os
import sys
@@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
-k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
+k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args)
@@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
-# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
-git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
@@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI") sys.argv += args
+if "--exit" in args:
+ print("Exiting because of --exit argument")
+ exit(0)
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
diff --git a/models/Put Stable Diffusion checkpoints here.txt b/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt index e69de29b..e69de29b 100644 --- a/models/Put Stable Diffusion checkpoints here.txt +++ b/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py new file mode 100644 index 00000000..e62c6657 --- /dev/null +++ b/modules/bsrgan_model.py @@ -0,0 +1,78 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.bsrgan_model_arch import RRDBNet +from modules.paths import models_path + + +class UpscalerBSRGAN(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "BSRGAN" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "BSRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pt", ".pth"]) + scalers = [] + if len(model_paths) == 0: + scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) + scalers.append(scaler_data) + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading BSRGAN model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + model = self.load_model(selected_file) + if model is None: + return img + model.to(shared.device) + torch.cuda.empty_cache() + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr) + return None + model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + return model + diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py new file mode 100644 index 00000000..cb4d1c13 --- /dev/null +++ b/modules/bsrgan_model_arch.py @@ -0,0 +1,102 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.sf = sf + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.sf==4: + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + if self.sf==4: + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out
\ No newline at end of file diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 2177291a..8769e1db 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,31 +5,31 @@ import traceback import cv2
import torch
-from modules import shared, devices
-from modules.paths import script_path
-import modules.shared
import modules.face_restoration
-from importlib import reload
+import modules.shared
+from modules import shared, devices, modelloader
+from modules.paths import script_path, models_path
-# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
-# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# codeformer people made a choice to include modified basicsr library to their project which makes
+# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None
-def setup_codeformer():
+
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
-
- # both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
- #stored_sys_path = sys.path
- #sys.path = [path] + sys.path
-
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
@@ -44,18 +44,23 @@ def setup_codeformer(): def name(self):
return "CodeFormer"
- def __init__(self):
+ def __init__(self, dirname):
self.net = None
self.face_helper = None
+ self.cmd_dir = dirname
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
-
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+ if len(model_paths) != 0:
+ ckpt_path = model_paths[0]
+ else:
+ print("Unable to load codeformer model.")
+ return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
@@ -74,6 +79,9 @@ def setup_codeformer(): original_resolution = np_image.shape[0:2]
self.create_models()
+ if self.net is None or self.face_helper is None:
+ return np_image
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -116,7 +124,7 @@ def setup_codeformer(): have_codeformer = True
global codeformer
- codeformer = FaceRestorerCodeFormer()
+ codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 7f3baf31..ea91abfe 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,26 +1,22 @@ import os
-import sys
-import traceback
import numpy as np
import torch
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-from modules import shared
-from modules.shared import opts
+from modules import shared, modelloader, images
from modules.devices import has_mps
-import modules.images
+from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
+from modules.shared import opts
-def load_model(filename):
+def fix_model_layers(crt_model, pretrained_net):
# this code is adapted from https://github.com/xinntao/ESRGAN
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
-
if 'conv_first.weight' in pretrained_net:
- crt_model.load_state_dict(pretrained_net)
- return crt_model
+ return pretrained_net
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
@@ -72,9 +68,59 @@ def load_model(filename): crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
- crt_model.load_state_dict(crt_net)
- crt_model.eval()
- return crt_model
+ return crt_net
+
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+ self.model_name = "ESRGAN 4x"
+ self.scalers = []
+ self.user_path = dirname
+ self.model_path = os.path.join(models_path, self.name)
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+ scalers = []
+ if len(model_paths) == 0:
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+ scalers.append(scaler_data)
+ for file in model_paths:
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+
+ scaler_data = UpscalerData(name, file, self, 4)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(shared.device)
+ img = esrgan_upscale(model, img)
+ return img
+
+ def load_model(self, path: str):
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="%s.pth" % self.model_name,
+ progress=True)
+ else:
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (self.model_path, filename))
+ return None
+
+ pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+ pretrained_net = fix_model_layers(crt_model, pretrained_net)
+ crt_model.load_state_dict(pretrained_net)
+ crt_model.eval()
+
+ return crt_model
+
def upscale_without_tiling(model, img):
img = np.array(img)
@@ -95,7 +141,7 @@ def esrgan_upscale(model, img): if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img)
- grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = []
scale_factor = 1
@@ -110,32 +156,6 @@ def esrgan_upscale(model, img): newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
- newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = modules.images.combine_grid(newgrid)
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
return output
-
-
-class UpscalerESRGAN(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.model = load_model(filename)
-
- def do_upscale(self, img):
- model = self.model.to(shared.device)
- img = esrgan_upscale(model, img)
- return img
-
-
-def load_models(dirname):
- for file in os.listdir(dirname):
- path = os.path.join(dirname, file)
- model_name, extension = os.path.splitext(file)
-
- if extension != '.pt' and extension != '.pth':
- continue
-
- try:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
- except Exception:
- print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/extras.py b/modules/extras.py index 38b86167..1d9e64e5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs = []
for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c
return c
@@ -143,7 +145,7 @@ def run_pnginfo(image): return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half):
+def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -191,8 +193,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int if save_as_half:
theta_0[key] = theta_0[key].half()
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
- output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
+ filename = filename if custom_name == '' else (custom_name + '.ckpt')
+ output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index b1288f0c..a5fd9632 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,39 +1,25 @@ import os
import sys
import traceback
-from glob import glob
-from modules import shared, devices
-from modules.shared import cmd_opts
-from modules.paths import script_path
-import modules.face_restoration
-
-
-def gfpgan_model_path():
- from modules.shared import cmd_opts
-
- filemask = 'GFPGAN*.pth'
-
- if cmd_opts.gfpgan_model is not None:
- return cmd_opts.gfpgan_model
-
- places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
-
- filename = None
- for place in places:
- filename = next(iter(glob(os.path.join(place, filemask))), None)
- if filename is not None:
- break
-
- return filename
+import facexlib
+import gfpgan
+import modules.face_restoration
+from modules import shared, devices, modelloader
+from modules.paths import models_path
+model_dir = "GFPGAN"
+user_path = None
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+have_gfpgan = False
loaded_gfpgan_model = None
-def gfpgan():
+def gfpgann():
global loaded_gfpgan_model
-
+ global model_path
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model
@@ -41,7 +27,16 @@ def gfpgan(): if gfpgan_constructor is None:
return None
- model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
+ latest_file = max(models, key=os.path.getctime)
+ model_file = latest_file
+ else:
+ print("Unable to load gfpgan model!")
+ return None
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
@@ -50,8 +45,9 @@ def gfpgan(): def gfpgan_fix_faces(np_image):
global loaded_gfpgan_model
- model = gfpgan()
-
+ model = gfpgann()
+ if model is None:
+ return np_image
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
@@ -64,21 +60,39 @@ def gfpgan_fix_faces(np_image): return np_image
-have_gfpgan = False
gfpgan_constructor = None
-def setup_gfpgan():
- try:
- gfpgan_model_path()
- if os.path.exists(cmd_opts.gfpgan_dir):
- sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
- from gfpgan import GFPGANer
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ try:
+ from gfpgan import GFPGANer
+ from facexlib import detection, parsing
+ global user_path
global have_gfpgan
- have_gfpgan = True
-
global gfpgan_constructor
+
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+ def my_load_file_from_url(**kwargs):
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+ def facex_load_file_from_url(**kwargs):
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ def facex_load_file_from_url2(**kwargs):
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+ user_path = dirname
+ have_gfpgan = True
gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
diff --git a/modules/images.py b/modules/images.py index 9458bf8d..f1aed5d6 100644 --- a/modules/images.py +++ b/modules/images.py @@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from fonts.ttf import Roboto
import string
-import modules.shared
from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts
@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height)
- dx = (w - tile_w) / (cols-1) if cols > 1 else 0
- dy = (h - tile_h) / (rows-1) if rows > 1 else 0
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): for col in range(cols):
x = int(col * dx)
- if x+tile_w >= w:
+ if x + tile_w >= w:
x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h))
@@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
- drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing
@@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
+ ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2
@@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height): if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
- upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
- return upscaler.upscale(im, w, h)
+ scale = max(w / im.width, h / im.height)
+
+ if scale > 1.0:
+ upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
+ assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+
+ upscaler = upscalers[0]
+ im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
+
+ if im.width != w or im.height != h:
+ im = im.resize((w, h), resample=LANCZOS)
+
+ return im
if resize_mode == 0:
res = resize(im, width, height)
@@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height): invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
-re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
+re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128
@@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt): if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt))
+ if "[prompt_no_styles]" in x:
+ prompt_no_style = prompt
+ for style in shared.prompt_styles.get_style_prompts(p.styles):
+ if len(style) > 0:
+ style_parts = [y for y in style.split("{prompt}")]
+ for part in style_parts:
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
+ x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
+
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
if "[prompt_words]" in x:
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
@@ -290,7 +311,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
@@ -303,6 +324,7 @@ def apply_filename_pattern(x, p, seed, prompt): return x
+
def get_next_sequence_number(path, basename):
"""
Determines and returns the next sequence number to use when saving an image in the specified directory.
@@ -316,7 +338,7 @@ def get_next_sequence_number(path, basename): prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
result = max(int(l[0]), result)
except ValueError:
@@ -324,6 +346,7 @@ def get_next_sequence_number(path, basename): return result + 1
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None:
file_decoration = ""
@@ -361,7 +384,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn = "a.png"
fullfn_without_extension = "a"
for i in range(500):
- fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
@@ -403,31 +426,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file.write(info + "\n")
-class Upscaler:
- name = "Lanczos"
-
- def do_upscale(self, img):
- return img
-
- def upscale(self, img, w, h):
- for i in range(3):
- if img.width >= w and img.height >= h:
- break
-
- img = self.do_upscale(img)
-
- if img.width != w or img.height != h:
- img = img.resize((int(w), int(h)), resample=LANCZOS)
-
- return img
-
-
-class UpscalerNone(Upscaler):
- name = "None"
-
- def upscale(self, img, w, h):
- return img
-
-
-modules.shared.sd_upscalers.append(UpscalerNone())
-modules.shared.sd_upscalers.append(Upscaler())
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 95e84659..1c1070fc 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -1,67 +1,56 @@ import os import sys import traceback -from collections import namedtuple from basicsr.utils.download_util import load_file_from_url -import modules.images +from modules.upscaler import Upscaler, UpscalerData +from modules.ldsr_model_arch import LDSR from modules import shared -from modules.paths import script_path +from modules.paths import models_path -LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) -ldsr_models = [] -have_ldsr = False -LDSR_obj = None - - -class UpscalerLDSR(modules.images.Upscaler): - def __init__(self, steps): - self.steps = steps +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): self.name = "LDSR" - - def do_upscale(self, img): - return upscale_with_ldsr(img) - - -def add_lsdr(): - modules.shared.sd_upscalers.append(UpscalerLDSR(100)) - - -def setup_ldsr(): - path = modules.paths.paths.get("LDSR", None) - if path is None: - return - global have_ldsr - global LDSR_obj - try: - from LDSR import LDSR - model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" - repo_path = 'latent-diffusion/experiments/pretrained_models/' - model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="model.chkpt") - yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="project.yaml") - have_ldsr = True - LDSR_obj = LDSR(model_path, yaml_path) - - - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - have_ldsr = False - - -def upscale_with_ldsr(image): - setup_ldsr() - if not have_ldsr or LDSR_obj is None: - return image - - ddim_steps = shared.opts.ldsr_steps - pre_scale = shared.opts.ldsr_pre_down - post_scale = shared.opts.ldsr_post_down - - image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale) - return image + self.model_path = os.path.join(models_path, self.name) + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + # Remove incorrect project.yaml file if too big + yaml_path = os.path.join(self.model_path, "project.yaml") + old_model_path = os.path.join(self.model_path, "model.pth") + new_model_path = os.path.join(self.model_path, "model.ckpt") + if os.path.exists(yaml_path): + statinfo = os.stat(yaml_path) + if statinfo.st_size >= 10485760: + print("Removing invalid LDSR YAML file.") + os.remove(yaml_path) + if os.path.exists(old_model_path): + print("Renaming model from model.pth to model.ckpt") + os.rename(old_model_path, new_model_path) + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.ckpt", progress=True) + yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py new file mode 100644 index 00000000..14db5076 --- /dev/null +++ b/modules/ldsr_model_arch.py @@ -0,0 +1,222 @@ +import gc +import time +import warnings + +import numpy as np +import torch +import torchvision +from PIL import Image +from einops import rearrange, repeat +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config, ismap + +warnings.filterwarnings("ignore", category=UserWarning) + + +# Create LDSR Class +class LDSR: + def load_model_from_config(self, half_attention): + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + model = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model.cuda() + if half_attention: + model = model.half() + + model.eval() + return {"model": model} + + def __init__(self, model_path, yaml_path): + self.modelPath = model_path + self.yamlPath = yaml_path + + @staticmethod + def run(model, selected_path, custom_steps, eta): + example = get_cond(selected_path) + + n_runs = 1 + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = 1. + eta = eta + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + x_t = None + logs = None + for n in range(n_runs): + if custom_shape is not None: + x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + custom_steps=custom_steps, + eta=eta, quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, + ddim_use_x0_pred=ddim_use_x0_pred + ) + return logs + + def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): + model = self.load_model_from_config(half_attention) + + # Run settings + diffusion_steps = int(steps) + eta = 1.0 + + down_sample_method = 'Lanczos' + + gc.collect() + torch.cuda.empty_cache() + + im_og = image + width_og, height_og = im_og.size + # If we can adjust the max upscale size, then the 4 below should be our variable + down_sample_rate = target_scale / 4 + wd = width_og * down_sample_rate + hd = height_og * down_sample_rate + width_downsampled_pre = int(wd) + height_downsampled_pre = int(hd) + + if down_sample_rate != 1: + print( + f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') + im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) + else: + print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") + logs = self.run(model["model"], im_og, diffusion_steps, eta) + + sample = logs["sample"] + sample = sample.detach().cpu() + sample = torch.clamp(sample, -1., 1.) + sample = (sample + 1.) / 2. * 255 + sample = sample.numpy().astype(np.uint8) + sample = np.transpose(sample, (0, 2, 3, 1)) + a = Image.fromarray(sample[0]) + + del model + gc.collect() + torch.cuda.empty_cache() + return a + + +def get_cond(selected_path): + example = dict() + up_f = 4 + c = selected_path.convert('RGB') + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], + antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. + + c = c.to(torch.device("cuda")) + example["LR_image"] = c + example["image"] = c_up + + return example + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, + mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, + corrector_kwargs=None, x_t=None + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, + normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, + mask=mask, x0=x0, temperature=temperature, verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, x_t=x_t) + + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, + corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): + log = dict() + + z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, 'split_input_params') + and model.cond_stage_key == 'coordinates_bbox'), + return_original_cond=True) + + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + z0 = None + + log["input"] = x + log["reconstruction"] = xrec + + if ismap(xc): + log["original_conditioning"] = model.to_rgb(xc) + if hasattr(model, 'cond_stage_key'): + log[model.cond_stage_key] = model.to_rgb(xc) + + else: + log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_model: + log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_key == 'class_label': + log[model.cond_stage_key] = xc[model.cond_stage_key] + + with model.ema_scope("Plotting"): + t0 = time.time() + + sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, mask=None, x0=z0, + temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, + x_t=x_T) + t1 = time.time() + + if ddim_use_x0_pred: + sample = intermediates['pred_x0'][-1] + + x_sample = model.decode_first_stage(sample) + + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + log["sample"] = x_sample + log["time"] = t1 - t0 + + return log diff --git a/modules/modelloader.py b/modules/modelloader.py new file mode 100644 index 00000000..8c862b42 --- /dev/null +++ b/modules/modelloader.py @@ -0,0 +1,140 @@ +import glob +import os +import shutil +import importlib +from urllib.parse import urlparse + +from basicsr.utils.download_util import load_file_from_url + +from modules import shared +from modules.upscaler import Upscaler +from modules.paths import script_path, models_path + + +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: + """ + A one-and done loader to try finding the desired models in specified directories. + + @param download_name: Specify to download from model_url immediately. + @param model_url: If no other models are found, this will be downloaded on upscale. + @param model_path: The location to store/find models in. + @param command_path: A command-line argument to search for models in first. + @param ext_filter: An optional list of filename extensions to filter by + @return: A list of paths containing the desired model(s) + """ + output = [] + + if ext_filter is None: + ext_filter = [] + + try: + places = [] + + if command_path is not None and command_path != model_path: + pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') + if os.path.exists(pretrained_path): + print(f"Appending path: {pretrained_path}") + places.append(pretrained_path) + elif os.path.exists(command_path): + places.append(command_path) + + places.append(model_path) + + for place in places: + if os.path.exists(place): + for file in glob.iglob(place + '**/**', recursive=True): + full_path = os.path.join(place, file) + if os.path.isdir(full_path): + continue + if len(ext_filter) != 0: + model_name, extension = os.path.splitext(file) + if extension not in ext_filter: + continue + if file not in output: + output.append(full_path) + + if model_url is not None and len(output) == 0: + if download_name is not None: + dl = load_file_from_url(model_url, model_path, True, download_name) + output.append(dl) + else: + output.append(model_url) + + except Exception: + pass + + return output + + +def friendly_name(file: str): + if "http" in file: + file = urlparse(file).path + + file = os.path.basename(file) + model_name, extension = os.path.splitext(file) + return model_name + + +def cleanup_models(): + # This code could probably be more efficient if we used a tuple list or something to store the src/destinations + # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler + # somehow auto-register and just do these things... + root_path = script_path + src_path = models_path + dest_path = os.path.join(models_path, "Stable-diffusion") + move_files(src_path, dest_path, ".ckpt") + src_path = os.path.join(root_path, "ESRGAN") + dest_path = os.path.join(models_path, "ESRGAN") + move_files(src_path, dest_path) + src_path = os.path.join(root_path, "gfpgan") + dest_path = os.path.join(models_path, "GFPGAN") + move_files(src_path, dest_path) + src_path = os.path.join(root_path, "SwinIR") + dest_path = os.path.join(models_path, "SwinIR") + move_files(src_path, dest_path) + src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") + dest_path = os.path.join(models_path, "LDSR") + move_files(src_path, dest_path) + + +def move_files(src_path: str, dest_path: str, ext_filter: str = None): + try: + if not os.path.exists(dest_path): + os.makedirs(dest_path) + if os.path.exists(src_path): + for file in os.listdir(src_path): + fullpath = os.path.join(src_path, file) + if os.path.isfile(fullpath): + if ext_filter is not None: + if ext_filter not in file: + continue + print(f"Moving {file} from {src_path} to {dest_path}.") + try: + shutil.move(fullpath, dest_path) + except: + pass + if len(os.listdir(src_path)) == 0: + print(f"Removing empty folder: {src_path}") + shutil.rmtree(src_path, True) + except: + pass + + +def load_upscalers(): + datas = [] + for cls in Upscaler.__subclasses__(): + name = cls.__name__ + module_name = cls.__module__ + module = importlib.import_module(module_name) + class_ = getattr(module, name) + cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + opt_string = None + try: + opt_string = shared.opts.__getattr__(cmd_name) + except: + pass + scaler = class_(opt_string) + for child in scaler.scalers: + datas.append(child) + + shared.sd_upscalers = datas diff --git a/modules/paths.py b/modules/paths.py index df7b9d9a..ceb80417 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -3,9 +3,10 @@ import os import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path)
-# search for directory of stable diffsuion in following palces
+# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
@@ -15,21 +16,24 @@ for possible_sd_path in possible_sd_paths: assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
path_dirs = [
- (sd_path, 'ldm', 'Stable Diffusion'),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
- (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
- (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
- (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
- (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
+ (sd_path, 'ldm', 'Stable Diffusion', []),
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
+ (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
+ (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
paths = {}
-for d, must_exist, what in path_dirs:
+for d, must_exist, what, options in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else:
d = os.path.abspath(d)
- sys.path.append(d)
+ if "atstart" in options:
+ sys.path.insert(0, d)
+ else:
+ sys.path.append(d)
paths[what] = d
diff --git a/modules/processing.py b/modules/processing.py index de5cda79..a838ebb3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -508,8 +508,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
- upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
- image = upscaler.upscale(image, self.width, self.height)
+ image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a6a25b28..e811eb9e 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ Example:
+
+ 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+
+ produces:
+
+ [
+ ['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]
+ ]
+ """
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith('\\'):
+ res.append([text[1:], 1.0])
+ elif text == '(':
+ round_brackets.append(len(res))
+ elif text == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ return res
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c32d6c4c..dc0123e0 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,119 +1,135 @@ +import os
import sys
import traceback
-from collections import namedtuple
import numpy as np
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
-import modules.images
+from modules.upscaler import Upscaler, UpscalerData
+from modules.paths import models_path
from modules.shared import cmd_opts, opts
-RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
-realesrgan_models = []
-have_realesrgan = False
+class UpscalerRealESRGAN(Upscaler):
+ def __init__(self, path):
+ self.name = "RealESRGAN"
+ self.model_path = os.path.join(models_path, self.name)
+ self.user_path = path
+ super().__init__()
+ try:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ self.enable = True
+ self.scalers = []
+ scalers = self.load_models(path)
+ for scaler in scalers:
+ if scaler.name in opts.realesrgan_enabled_models:
+ self.scalers.append(scaler)
+
+ except Exception:
+ print("Error importing Real-ESRGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ self.enable = False
+ self.scalers = []
+
+ def do_upscale(self, img, path):
+ if not self.enable:
+ return img
+
+ info = self.load_model(path)
+ if not os.path.exists(info.data_path):
+ print("Unable to load RealESRGAN model: %s" % info.name)
+ return img
+
+ upsampler = RealESRGANer(
+ scale=info.scale,
+ model_path=info.data_path,
+ model=info.model(),
+ half=not cmd_opts.no_half,
+ tile=opts.ESRGAN_tile,
+ tile_pad=opts.ESRGAN_tile_overlap,
+ )
+
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
+
+ image = Image.fromarray(upsampled)
+ return image
+
+ def load_model(self, path):
+ try:
+ info = None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ info = scaler
+
+ if info is None:
+ print(f"Unable to find model info: {path}")
+ return None
+
+ model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+ info.data_path = model_file
+ return info
+ except Exception as e:
+ print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
-def get_realesrgan_models():
+ def load_models(self, _):
+ return get_realesrgan_models(self)
+
+
+def get_realesrgan_models(scaler):
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [
- RealesrganModelInfo(
- name="Real-ESRGAN General x4x3",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN General 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN General WDN x4x3",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN General WDN 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN AnimeVideo",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN AnimeVideo",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN 4x plus",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN 4x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
),
- RealesrganModelInfo(
- name="Real-ESRGAN 4x plus anime 6B",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN 4x+ Anime6B",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
- RealesrganModelInfo(
- name="Real-ESRGAN 2x plus",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- netscale=2,
+ UpscalerData(
+ name="R-ESRGAN 2x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ scale=2,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
return models
except Exception as e:
- print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
-
-
-class UpscalerRealESRGAN(modules.images.Upscaler):
- def __init__(self, upscaling, model_index):
- self.upscaling = upscaling
- self.model_index = model_index
- self.name = realesrgan_models[model_index].name
-
- def do_upscale(self, img):
- return upscale_with_realesrgan(img, self.upscaling, self.model_index)
-
-
-def setup_realesrgan():
- global realesrgan_models
- global have_realesrgan
-
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-
- realesrgan_models = get_realesrgan_models()
- have_realesrgan = True
-
- for i, model in enumerate(realesrgan_models):
- if model.name in opts.realesrgan_enabled_models:
- modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
-
- except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
- have_realesrgan = False
-
-
-def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
- if not have_realesrgan:
- return image
-
- info = realesrgan_models[RealESRGAN_model_index]
-
- model = info.model()
- upsampler = RealESRGANer(
- scale=info.netscale,
- model_path=info.location,
- model=model,
- half=not cmd_opts.no_half,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
- )
-
- upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
-
- image = Image.fromarray(upsampled)
- return image
diff --git a/modules/scripts.py b/modules/scripts.py index 202374e6..7c3bd5e7 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -55,7 +55,7 @@ def load_scripts(basedir): if not os.path.exists(basedir):
return
- for filename in os.listdir(basedir):
+ for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename)
if not os.path.isfile(path):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4bc58fa2..317e0c4c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,6 +6,7 @@ import torch import numpy as np
from torch import einsum
+from modules import prompt_parser
from modules.shared import opts, device, cmd_opts
from ldm.util import default
@@ -204,6 +205,7 @@ class StableDiffusionModelHijack: param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
@@ -223,19 +225,25 @@ class StableDiffusionModelHijack: for fn in os.listdir(dirname):
try:
- process_file(os.path.join(dirname, fn), fn)
+ fullfn = os.path.join(dirname, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
self.clip = m.cond_stage_model
ldm.modules.diffusionmodules.model.nonlinearity = silu
@@ -255,6 +263,14 @@ class StableDiffusionModelHijack: self.layers = flatten(m)
+ def undo_hijack(self, m):
+ if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
@@ -269,6 +285,7 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, max_length
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
@@ -294,7 +311,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0:
self.token_mults[ident] = mult
- def process_text(self, text):
+
+ def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ id_start = self.wrapped.tokenizer.bos_token_id
+ id_end = self.wrapped.tokenizer.eos_token_id
+ maxlen = self.wrapped.max_length
+
+ if opts.enable_emphasis:
+ parsed = prompt_parser.parse_prompt_attention(line)
+ else:
+ parsed = [[line, 1.0]]
+
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ possible_matches = self.hijack.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ else:
+ found = False
+ for ids, word in possible_matches:
+ if tokens[i:i + len(ids)] == ids:
+ emb_len = int(self.hijack.word_embeddings[word].shape[0])
+ fixes.append((len(remade_tokens), word))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ i += len(ids) - 1
+ found = True
+ used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+ break
+
+ if not found:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ i += 1
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ return remade_tokens, fixes, multipliers, token_count
+
+ def process_text(self, texts):
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_multipliers = []
+ for line in texts:
+ if line in cache:
+ remade_tokens, fixes, multipliers = cache[line]
+ else:
+ remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+
+ cache[line] = (remade_tokens, fixes, multipliers)
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+ def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
@@ -370,12 +472,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+ if opts.use_old_emphasis_implementation:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ else:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens)
diff --git a/modules/sd_models.py b/modules/sd_models.py index dd47dffb..2539f14c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,14 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config
-from modules import shared
+from modules import shared, modelloader
+from modules.paths import models_path
+
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
+model_name = "sd-v1-4.ckpt"
+model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
+user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
@@ -23,21 +30,30 @@ except Exception: pass
+def setup_model(dirname):
+ global user_dir
+ user_dir = dirname
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ checkpoints_list.clear()
+ list_models()
+
+
def checkpoint_tiles():
- print(sorted([x.title for x in checkpoints_list.values()]))
return sorted([x.title for x in checkpoints_list.values()])
def list_models():
checkpoints_list.clear()
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
- model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
-
- def modeltitle(path, h):
+ def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
- if abspath.startswith(model_dir):
- name = abspath.replace(model_dir, '')
+ if user_dir is not None and abspath.startswith(user_dir):
+ name = abspath.replace(user_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)
@@ -46,21 +62,27 @@ def list_models(): shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
- return f'{name} [{h}]', shortname
+ return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
- title, model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
+ title, short_model_name = modeltitle(cmd_ckpt, h)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
+ shared.opts.sd_model_checkpoint = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
- print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
+ print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+ for filename in model_list:
+ h = model_hash(filename)
+ title, short_model_name = modeltitle(filename, h)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
- if os.path.exists(model_dir):
- for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
- h = model_hash(filename)
- title, model_name = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
+
+def get_closet_checkpoint_match(searchString):
+ applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
+ if len(applicable) > 0:
+ return applicable[0]
+ return None
def model_hash(filename):
@@ -138,7 +160,7 @@ def load_model(): def reload_model_weights(sd_model, info=None):
- from modules import lowvram, devices
+ from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -149,8 +171,12 @@ def reload_model_weights(sd_model, info=None): else:
sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(sd_model)
+
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
+ sd_hijack.model_hijack.hijack(sd_model)
+
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fc0c94b4..dff89c09 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -4,7 +4,6 @@ import torch import tqdm
from PIL import Image
import inspect
-
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
@@ -23,6 +22,8 @@ samplers_k_diffusion = [ ('Heun', 'sample_heun', ['k_heun']),
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]
samplers_data_k_diffusion = [
@@ -36,7 +37,7 @@ samplers = [ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
-samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
@@ -309,8 +310,13 @@ class KDiffusionSampler: x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
-
- samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
-
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = steps
+ else:
+ extra_params_kwargs['sigmas'] = sigmas
+ samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
diff --git a/modules/shared.py b/modules/shared.py index ec1e569b..ac968b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,26 +1,27 @@ -import sys
import argparse
+import datetime
import json
import os
+import sys
+
import gradio as gr
import tqdm
-import datetime
import modules.artists
-from modules.paths import script_path, sd_path
-from modules.devices import get_optimal_device
-import modules.styles
import modules.interrogate
import modules.memmon
import modules.sd_models
+import modules.styles
+from modules.devices import get_optimal_device
+from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
-
+model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
-parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
-parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
+parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
+parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -34,8 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
-parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
-parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
+parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
+parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
+parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
+parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
+parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
+parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
+parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@@ -53,7 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
cmd_opts = parser.parse_args()
-
device = get_optimal_device()
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
@@ -61,6 +66,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file
+
class State:
interrupted = False
job = ""
@@ -95,13 +101,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-
-modules.sd_models.list_models()
+# This was moved to webui.py with the other model "setup" calls.
+# modules.sd_models.list_models()
def realesrgan_models_names():
import modules.realesrgan_model
- return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
+ return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
@@ -167,13 +173,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
- "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
- "ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
-
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
@@ -190,12 +193,13 @@ options_templates.update(options_section(('system', "System"), { }))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
diff --git a/modules/styles.py b/modules/styles.py index eeedcd08..d44dfc1a 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -53,6 +53,12 @@ class StyleDatabase: negative_prompt = row.get("negative_prompt", "")
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ def get_style_prompts(self, styles):
+ return [self.styles.get(x, self.no_style).prompt for x in styles]
+
+ def get_negative_style_prompts(self, styles):
+ return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
+
def apply_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
diff --git a/modules/swinir.py b/modules/swinir.py deleted file mode 100644 index 8c534495..00000000 --- a/modules/swinir.py +++ /dev/null @@ -1,123 +0,0 @@ -import sys
-import traceback
-import cv2
-import os
-import contextlib
-import numpy as np
-from PIL import Image
-import torch
-import modules.images
-from modules.shared import cmd_opts, opts, device
-from modules.swinir_arch import SwinIR as net
-
-precision_scope = (
- torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-)
-
-
-def load_model(filename, scale=4):
- model = net(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
-
- pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
- if not cmd_opts.no_half:
- model = model.half()
- return model
-
-
-def load_models(dirname):
- for file in os.listdir(dirname):
- path = os.path.join(dirname, file)
- model_name, extension = os.path.splitext(file)
-
- if extension != ".pt" and extension != ".pth":
- continue
-
- try:
- modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
- except Exception:
- print(f"Error loading SwinIR model: {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
-
-def upscale(
- img,
- model,
- tile=opts.SWIN_tile,
- tile_overlap=opts.SWIN_tile_overlap,
- window_size=8,
- scale=4,
-):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
- with torch.no_grad(), precision_scope("cuda"):
- _, _, h_old, w_old = img.size()
- h_pad = (h_old // window_size + 1) * window_size - h_old
- w_pad = (w_old // window_size + 1) * window_size - w_old
- img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
- img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
- output = output[..., : h_old * scale, : w_old * scale]
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
- if output.ndim == 3:
- output = np.transpose(
- output[[2, 1, 0], :, :], (1, 2, 0)
- ) # CHW-RGB to HCW-BGR
- output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
-
- for h_idx in h_idx_list:
- for w_idx in w_idx_list:
- in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
- ].add_(out_patch_mask)
- output = E.div_(W)
-
- return output
-
-
-class UpscalerSwin(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.model = load_model(filename)
-
- def do_upscale(self, img):
- model = self.model.to(device)
- img = upscale(img, model)
- return img
diff --git a/modules/swinir_model.py b/modules/swinir_model.py new file mode 100644 index 00000000..41fda5a7 --- /dev/null +++ b/modules/swinir_model.py @@ -0,0 +1,139 @@ +import contextlib +import os + +import numpy as np +import torch +from PIL import Image +from basicsr.utils.download_util import load_file_from_url + +from modules import modelloader +from modules.paths import models_path +from modules.shared import cmd_opts, opts, device +from modules.swinir_model_arch import SwinIR as net +from modules.upscaler import Upscaler, UpscalerData + +precision_scope = ( + torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext +) + + +class UpscalerSwinIR(Upscaler): + def __init__(self, dirname): + self.name = "SwinIR" + self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ + "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ + "-L_x4_GAN.pth " + self.model_name = "SwinIR 4x" + self.model_path = os.path.join(models_path, self.name) + self.user_path = dirname + super().__init__() + scalers = [] + model_files = self.find_models(ext_filter=[".pt", ".pth"]) + for model in model_files: + if "http" in model: + name = self.model_name + else: + name = modelloader.friendly_name(model) + model_data = UpscalerData(name, model, self) + scalers.append(model_data) + self.scalers = scalers + + def do_upscale(self, img, model_file): + model = self.load_model(model_file) + if model is None: + return img + model = model.to(device) + img = upscale(img, model) + try: + torch.cuda.empty_cache() + except: + pass + return img + + def load_model(self, path, scale=4): + if "http" in path: + dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + + pretrained_model = torch.load(filename) + model.load_state_dict(pretrained_model["params_ema"], strict=True) + if not cmd_opts.no_half: + model = model.half() + return model + + +def upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + window_size=8, + scale=4, +): + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device) + with torch.no_grad(), precision_scope("cuda"): + _, _, h_old, w_old = img.size() + h_pad = (h_old // window_size + 1) * window_size - h_old + w_pad = (w_old // window_size + 1) * window_size - w_old + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] + output = inference(img, model, tile, tile_overlap, window_size, scale) + output = output[..., : h_old * scale, : w_old * scale] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose( + output[[2, 1, 0], :, :], (1, 2, 0) + ) # CHW-RGB to HCW-BGR + output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 + return Image.fromarray(output, "RGB") + + +def inference(img, model, tile, tile_overlap, window_size, scale): + # test the image tile by tile + b, c, h, w = img.size() + tile = min(tile, h, w) + assert tile % window_size == 0, "tile size should be a multiple of window_size" + sf = scale + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) + W = torch.zeros_like(E, dtype=torch.half, device=device) + + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + output = E.div_(W) + + return output diff --git a/modules/swinir_arch.py b/modules/swinir_model_arch.py index a5eb9a36..461fb354 100644 --- a/modules/swinir_arch.py +++ b/modules/swinir_model_arch.py @@ -1,867 +1,867 @@ -# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r""" SwinIR
- A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(SwinIR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- if self.upscale == 4:
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = SwinIR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
+# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(SwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = SwinIR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/modules/ui.py b/modules/ui.py index 87024238..15572bb0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -9,10 +9,13 @@ import random import sys
import time
import traceback
+import platform
+import subprocess as sp
import numpy as np
import torch
from PIL import Image, PngImagePlugin
+import piexif
import gradio as gr
import gradio.utils
@@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
-
+folder_symbol = '\uD83D\uDCC2'
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -111,18 +114,26 @@ def save_files(js_data, images, index): writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000))
+ extension = opts.samples_format.lower()
for i, filedata in enumerate(images):
- filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
+ filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
filepath = os.path.join(opts.outdir_save, filename)
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
- pnginfo = PngImagePlugin.PngInfo()
- pnginfo.add_text('parameters', infotexts[i])
-
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
- image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
+ if opts.enable_pnginfo and extension == 'png':
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text('parameters', infotexts[i])
+ image.save(filepath, pnginfo=pnginfo)
+ else:
+ image.save(filepath, quality=opts.jpeg_quality)
+
+ if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
+ piexif.insert(piexif.dump({"Exif": {
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
+ }}), filepath)
filenames.append(filename)
@@ -369,7 +380,7 @@ def create_toprow(is_img2img): with gr.Column(scale=1):
with gr.Row():
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id="generate", variant='primary')
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
interrupt.click(
fn=lambda: shared.state.interrupt(),
@@ -461,6 +472,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): send_to_img2img = gr.Button('Send to img2img')
send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group():
html_info = gr.HTML()
@@ -586,7 +599,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
@@ -637,6 +650,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group():
html_info = gr.HTML()
@@ -809,6 +824,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): html_info = gr.HTML()
extras_send_to_img2img = gr.Button('Send to img2img')
extras_send_to_inpaint = gr.Button('Send to inpaint')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
+ open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
fn=run_extras,
@@ -874,6 +891,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
@@ -907,6 +925,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): components = []
component_dict = {}
+ def open_folder(f):
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
def run_settings(*args):
changed = 0
@@ -1013,15 +1041,26 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): inputs=components,
outputs=[result, text_settings],
)
+
+ def modelmerger(*args):
+ try:
+ results = run_modelmerger(*args)
+ except Exception as e:
+ print("Error loading/saving model file:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ modules.sd_models.list_models() #To remove the potentially missing models from the list
+ return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+ return results
modelmerger_merge.click(
- fn=run_modelmerger,
+ fn=modelmerger,
inputs=[
primary_model_name,
secondary_model_name,
interp_method,
interp_amount,
save_as_half,
+ custom_name,
],
outputs=[
submit_result,
@@ -1068,6 +1107,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[extras_image],
)
+ open_txt2img_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
+ inputs=[],
+ outputs=[],
+ )
+
+ open_img2img_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
+ inputs=[],
+ outputs=[],
+ )
+
+ open_extras_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
+ inputs=[],
+ outputs=[],
+ )
+
img2img_send_to_extras.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_extras",
diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 00000000..d9d7c5e2 --- /dev/null +++ b/modules/upscaler.py @@ -0,0 +1,121 @@ +import os +from abc import abstractmethod + +import PIL +import numpy as np +import torch +from PIL import Image + +import modules.shared +from modules import modelloader, shared + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +from modules.paths import models_path + + +class Upscaler: + name = None + model_path = None + model_name = None + model_url = None + enable = True + filter = None + model = None + user_path = None + scalers: [] + tile = True + + def __init__(self, create_dirs=False): + self.mod_pad_h = None + self.tile_size = modules.shared.opts.ESRGAN_tile + self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap + self.device = modules.shared.device + self.img = None + self.output = None + self.scale = 1 + self.half = not modules.shared.cmd_opts.no_half + self.pre_pad = 0 + self.mod_scale = None + if self.name is not None and create_dirs: + self.model_path = os.path.join(models_path, self.name) + if not os.path.exists(self.model_path): + os.makedirs(self.model_path) + + try: + import cv2 + self.can_tile = True + except: + pass + + @abstractmethod + def do_upscale(self, img: PIL.Image, selected_model: str): + return img + + def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): + self.scale = scale + dest_w = img.width * scale + dest_h = img.height * scale + for i in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + img = self.do_upscale(img, selected_model) + if img.width != dest_w or img.height != dest_h: + img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) + + return img + + @abstractmethod + def load_model(self, path: str): + pass + + def find_models(self, ext_filter=None) -> list: + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) + + def update_status(self, prompt): + print(f"\nextras: {prompt}", file=shared.progress_print_out) + + +class UpscalerData: + name = None + data_path = None + scale: int = 4 + scaler: Upscaler = None + model: None + + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + self.name = name + self.data_path = path + self.scaler = upscaler + self.scale = scale + self.model = model + + +class UpscalerNone(Upscaler): + name = "None" + scalers = [] + + def load_model(self, path): + pass + + def do_upscale(self, img, selected_model=None): + return img + + def __init__(self, dirname=None): + super().__init__(False) + self.scalers = [UpscalerData("None", None, self)] + + +class UpscalerLanczos(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Lanczos" + self.scalers = [UpscalerData("Lanczos", None, self)] + diff --git a/requirements.txt b/requirements.txt index 0d9929ca..7cb9d329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ fairscale==0.4.4 fonts
font-roboto
gfpgan
-gradio
+gradio==3.4b3
invisible-watermark
numpy
omegaconf
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 9719bb8f..11613ca3 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -11,46 +11,8 @@ from modules import images, processing, devices from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
-# https://github.com/parlance-zz/g-diffuser-bot
-def expand(x, dir, amount, power=0.75):
- is_left = dir == 3
- is_right = dir == 1
- is_up = dir == 0
- is_down = dir == 2
-
- if is_left or is_right:
- noise = np.zeros((x.shape[0], amount, 3), dtype=float)
- indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
- if is_right:
- indexes = 1 - indexes
- indexes = (indexes * (x.shape[1] - 1)).astype(int)
-
- for row in range(x.shape[0]):
- if is_left:
- noise[row] = x[row][indexes[row]]
- else:
- noise[row] = np.flip(x[row][indexes[row]], axis=0)
-
- x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
- return x
-
- if is_up or is_down:
- noise = np.zeros((amount, x.shape[1], 3), dtype=float)
- indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
- if is_down:
- indexes = 1 - indexes
- indexes = (indexes * x.shape[0] - 1).astype(int)
-
- for row in range(x.shape[1]):
- if is_up:
- noise[:, row] = x[:, row][indexes[row]]
- else:
- noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
-
- x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
- return x
-
+# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index b87a145b..2653e2d4 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -34,7 +34,7 @@ class Script(scripts.Script): seed = p.seed
init_img = p.init_images[0]
- img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
+ img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
devices.torch_gc()
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 24fa5a0a..146663b0 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -45,11 +45,8 @@ def apply_sampler(p, x, xs): def apply_checkpoint(p, x, xs):
- applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
- assert len(applicable) > 0, f'Checkpoint {x} for found'
-
- info = applicable[0]
-
+ info = modules.sd_models.get_closet_checkpoint_match(x)
+ assert info is not None, f'Checkpoint for {x} not found'
modules.sd_models.reload_model_weights(shared.sd_model, info)
@@ -159,6 +156,9 @@ class Script(scripts.Script): p.batch_size = 1
def process_axis(opt, vals):
+ if opt.label == 'Nothing':
+ return [0]
+
valslist = [x.strip() for x in vals.split(",")]
if opt.type == int:
@@ -1,5 +1,11 @@ .output-html p {margin: 0 0.5em;}
+.row > *,
+.row > .gr-form > * {
+ min-width: min(120px, 100%);
+ flex: 1 1 0%;
+}
+
.performance {
font-size: 0.85em;
color: #444;
@@ -17,7 +23,7 @@ text-align: right;
}
-#generate{
+#txt2img_generate, #img2img_generate {
min-height: 4.5em;
}
@@ -43,13 +49,17 @@ margin-right: auto;
}
-#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{
+#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
+#hidden_element{
+ display: none;
+}
+
#seed_row, #subseed_row{
gap: 0.5rem;
}
diff --git a/webui-user.sh b/webui-user.sh index 0ce41d3f..30646f5c 100644 --- a/webui-user.sh +++ b/webui-user.sh @@ -21,6 +21,9 @@ export COMMANDLINE_ARGS="" # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) #venv_dir="venv" +# script to launch to start the app +#export LAUNCH_SCRIPT="launch.py" + # install command for torch #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" @@ -1,37 +1,38 @@ import os
import threading
+from modules import devices
from modules.paths import script_path
-
import signal
-
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-import modules.ui
-import modules.scripts
-import modules.sd_hijack
-import modules.codeformer_model
-import modules.gfpgan_model
-import modules.face_restoration
-import modules.realesrgan_model as realesrgan
+import threading
+import modules.paths
+import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
-import modules.ldsr_model as ldsr
+import modules.bsrgan_model as bsrgan
import modules.extras
-import modules.lowvram
-import modules.txt2img
+import modules.face_restoration
+import modules.gfpgan_model as gfpgan
import modules.img2img
-import modules.swinir as swinir
+import modules.ldsr_model as ldsr
+import modules.lowvram
+import modules.realesrgan_model as realesrgan
+import modules.scripts
+import modules.sd_hijack
import modules.sd_models
+import modules.shared as shared
+import modules.swinir_model as swinir
+import modules.txt2img
+import modules.ui
+from modules import modelloader
+from modules.paths import script_path
+from modules.shared import cmd_opts
-
-modules.codeformer_model.setup_codeformer()
-modules.gfpgan_model.setup_gfpgan()
+modelloader.cleanup_models()
+modules.sd_models.setup_model(cmd_opts.ckpt_dir)
+codeformer.setup_model(cmd_opts.codeformer_models_path)
+gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-
-esrgan.load_models(cmd_opts.esrgan_models_path)
-swinir.load_models(cmd_opts.swinir_models_path)
-realesrgan.setup_realesrgan()
-ldsr.add_lsdr()
+modelloader.load_upscalers()
queue_lock = threading.Lock()
@@ -47,6 +48,8 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs):
+ devices.torch_gc()
+
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func): shared.state.job = ""
shared.state.job_count = 0
+ devices.torch_gc()
+
return res
return modules.ui.wrap_gradio_call(f)
@@ -41,6 +41,11 @@ then venv_dir="venv" fi +if [[ -z "${LAUNCH_SCRIPT}" ]] +then + LAUNCH_SCRIPT="launch.py" +fi + # Disable sentry logging export ERROR_REPORTING=FALSE @@ -133,4 +138,4 @@ fi printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." printf "\n%s\n" "${delimiter}" -"${python_cmd}" launch.py +"${python_cmd}" "${LAUNCH_SCRIPT}" |