aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-30 08:42:40 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-30 08:42:40 +0000
commitd1f098540ad1dbc2abb8d04322634efba650b631 (patch)
tree2c5cd3088177c938d643ed5430f7e2c38cdc2209
parent2552204fcb04c13af93749bd125d086ba148517c (diff)
downloadstable-diffusion-webui-gfx803-d1f098540ad1dbc2abb8d04322634efba650b631.tar.gz
stable-diffusion-webui-gfx803-d1f098540ad1dbc2abb8d04322634efba650b631.tar.bz2
stable-diffusion-webui-gfx803-d1f098540ad1dbc2abb8d04322634efba650b631.zip
remove unwanted formatting/functionality from the PR
-rw-r--r--launch.py7
-rw-r--r--modules/esrgan_model.py123
-rw-r--r--modules/extras.py35
-rw-r--r--modules/gfpgan_model.py12
-rw-r--r--modules/images.py37
-rw-r--r--modules/ldsr_model_arch.py1
-rw-r--r--modules/modelloader.py9
-rw-r--r--modules/realesrgan_model.py12
-rw-r--r--modules/sd_models.py56
-rw-r--r--modules/shared.py8
-rw-r--r--webui.py2
11 files changed, 127 insertions, 175 deletions
diff --git a/launch.py b/launch.py
index 3b8d8f23..d2793ed2 100644
--- a/launch.py
+++ b/launch.py
@@ -1,5 +1,4 @@
# this scripts installs necessary requirements and launches main program in webui.py
-import shutil
import subprocess
import os
import sys
@@ -119,11 +118,7 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
-if os.path.isdir(repo_dir('latent-diffusion')):
- try:
- shutil.rmtree(repo_dir('latent-diffusion'))
- except:
- pass
+
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index ce841aa4..ea91abfe 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -13,6 +13,63 @@ from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
+def fix_model_layers(crt_model, pretrained_net):
+ # this code is adapted from https://github.com/xinntao/ESRGAN
+ if 'conv_first.weight' in pretrained_net:
+ return pretrained_net
+
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
+ crt_net = crt_model.state_dict()
+ load_net_clean = {}
+ for k, v in pretrained_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ pretrained_net = load_net_clean
+
+ tbd = []
+ for k, v in crt_net.items():
+ tbd.append(k)
+
+ # directly copy
+ for k, v in crt_net.items():
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
+ crt_net[k] = pretrained_net[k]
+ tbd.remove(k)
+
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+ for k in tbd.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[k] = pretrained_net[ori_k]
+ tbd.remove(k)
+
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+ return crt_net
+
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
@@ -28,14 +85,12 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- print(f"File: {file}")
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, self, 4)
- print(f"ESRGAN: Adding scaler {name}")
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
@@ -56,67 +111,14 @@ class UpscalerESRGAN(Upscaler):
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- # this code is adapted from https://github.com/xinntao/ESRGAN
+
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
- if 'conv_first.weight' in pretrained_net:
- crt_model.load_state_dict(pretrained_net)
- return crt_model
-
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
- "params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
-
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- crt_model.load_state_dict(crt_net)
+ pretrained_net = fix_model_layers(crt_model, pretrained_net)
+ crt_model.load_state_dict(pretrained_net)
crt_model.eval()
+
return crt_model
@@ -154,7 +156,6 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
- grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid)
return output
diff --git a/modules/extras.py b/modules/extras.py
index 1d4e9fa8..1bff5874 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -67,28 +67,29 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
- def upscale(image, scaler_index, resize):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ if upscaling_resize != 1.0:
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- cached_images[key] = c
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ cached_images[key] = c
- return c
+ return c
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
- image = res
+ image = res
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 2bf8a1ee..bb30d733 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -36,8 +36,7 @@ def gfpgann():
else:
print("Unable to load gfpgan model!")
return None
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
- bg_upsampler=None)
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
@@ -49,8 +48,7 @@ def gfpgan_fix_faces(np_image):
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
- only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
@@ -79,7 +77,6 @@ def setup_model(dirname):
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs):
- print("Setting model_dir to " + model_path)
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
def facex_load_file_from_url(**kwargs):
@@ -92,7 +89,6 @@ def setup_model(dirname):
facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname
- print("Have gfpgan should be true?")
have_gfpgan = True
gfpgan_constructor = GFPGANer
@@ -102,9 +98,7 @@ def setup_model(dirname):
def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
- only_center_face=False,
- paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
diff --git a/modules/images.py b/modules/images.py
index 6430cfec..e89c44b2 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -84,10 +84,8 @@ def combine_grid(grid):
r = r.astype(np.uint8)
return Image.fromarray(r, 'L')
- mask_w = make_mask_image(
- np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
- mask_h = make_mask_image(
- np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
for y, h, row in grid.tiles:
@@ -130,12 +128,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
def draw_texts(drawing, draw_x, draw_y, lines):
for i, line in enumerate(lines):
- drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
- fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
- drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
- draw_y + line.size[1] // 2), fill=color_inactive, width=4)
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing
@@ -206,10 +202,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
prompts_horiz = prompts[:boundary]
prompts_vert = prompts[boundary:]
- hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
- range(1 << len(prompts_horiz))]
- ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
- range(1 << len(prompts_vert))]
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
@@ -259,13 +253,11 @@ def resize_image(resize_mode, im, width, height):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
- box=(0, fill_height + src_h))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
- box=(fill_width + src_w, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
return res
@@ -300,8 +292,7 @@ def apply_filename_pattern(x, p, seed, prompt):
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
if len(words) == 0:
words = ["empty"]
- x = x.replace("[prompt_words]",
- sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+ x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None:
x = x.replace("[steps]", str(p.steps))
@@ -309,8 +300,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
- x = x.replace("[sampler]",
- sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
@@ -336,8 +326,7 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split(
- '-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
result = max(int(l[0]), result)
except ValueError:
@@ -346,9 +335,7 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
- no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
- forced_filename=None, suffix=""):
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None:
file_decoration = ""
elif opts.save_to_dirs:
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
index f8f3c3d3..7faac6e1 100644
--- a/modules/ldsr_model_arch.py
+++ b/modules/ldsr_model_arch.py
@@ -125,7 +125,6 @@ class LDSR:
del model
gc.collect()
torch.cuda.empty_cache()
- print(f'Processing finished!')
return a
diff --git a/modules/modelloader.py b/modules/modelloader.py
index b3e6dc36..1106aeb7 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -25,8 +25,10 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if ext_filter is None:
ext_filter = []
+
try:
places = []
+
if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path):
@@ -34,7 +36,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(pretrained_path)
elif os.path.exists(command_path):
places.append(command_path)
+
places.append(model_path)
+
for place in places:
if os.path.exists(place):
for file in os.listdir(place):
@@ -47,14 +51,17 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
continue
if file not in output:
output.append(full_path)
+
if model_url is not None and len(output) == 0:
if download_name is not None:
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
output.append(model_url)
- except:
+
+ except Exception:
pass
+
return output
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 0a2eb896..dc0123e0 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -88,28 +88,24 @@ def get_realesrgan_models(scaler):
models = [
UpscalerData(
name="R-ESRGAN General 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
- ".pth",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
scale=4,
upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
- act_type='prelu')
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
UpscalerData(
name="R-ESRGAN General WDN 4xV3",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
scale=4,
upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
- act_type='prelu')
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
UpscalerData(
name="R-ESRGAN AnimeVideo",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
scale=4,
upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
- act_type='prelu')
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
),
UpscalerData(
name="R-ESRGAN 4x+",
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 4b9000a4..caa85d5e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -12,10 +12,10 @@ from modules import shared, modelloader
from modules.paths import models_path
model_dir = "Stable-diffusion"
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
-user_dir = None
+user_dir: (str | None) = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
@@ -30,26 +30,8 @@ except Exception:
pass
-def modeltitle(path, h):
- abspath = os.path.abspath(path)
-
- if abspath.startswith(model_dir):
- name = abspath.replace(model_dir, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- return f'{name} [{h}]'
-
-
def setup_model(dirname):
- global model_path
- global model_name
- global model_url
global user_dir
- global model_list
user_dir = dirname
if not os.path.exists(model_path):
os.makedirs(model_path)
@@ -62,21 +44,16 @@ def checkpoint_tiles():
def list_models():
- global model_path
- global model_url
- global model_name
- global user_dir
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir,
- ext_filter=[".ckpt"], download_name=model_name)
- print(f"Model list: {model_list}")
- model_dir = os.path.abspath(model_path)
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
- def modeltitle(path, h):
+ def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
- if abspath.startswith(model_dir):
- name = abspath.replace(model_dir, '')
+ if user_dir is not None and abspath.startswith(user_dir):
+ name = abspath.replace(user_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)
@@ -85,29 +62,30 @@ def list_models():
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
- return f'{name} [{h}]', shortname
+ return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
- title, model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
+ title, short_model_name = modeltitle(cmd_ckpt, h)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
- title, model_name = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
+ title, short_model_name = modeltitle(filename, h)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+
def get_closet_checkpoint_match(searchString):
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable)>0:
+ if len(applicable) > 0:
return applicable[0]
return None
+
def model_hash(filename):
try:
- print(f"Opening: {filename}")
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
@@ -128,7 +106,7 @@ def select_checkpoint():
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
+ print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
diff --git a/modules/shared.py b/modules/shared.py
index 69002158..03a1a4d3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -21,8 +21,7 @@ model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
-# This should be deprecated, but we'll leave it for a few iterations
-parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", )
+parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -41,7 +40,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
-parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
@@ -61,10 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
cmd_opts = parser.parse_args()
-if cmd_opts.ckpt_dir is not None:
- print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be "
- "removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.")
- cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir)
device = get_optimal_device()
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
diff --git a/webui.py b/webui.py
index 5fd65edc..b8cccd54 100644
--- a/webui.py
+++ b/webui.py
@@ -28,7 +28,7 @@ from modules.paths import script_path
from modules.shared import cmd_opts
modelloader.cleanup_models()
-modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path)
+modules.sd_models.setup_model(cmd_opts.ckpt_dir)
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())