diff options
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 20 | ||||
-rw-r--r-- | extensions-builtin/Lora/extra_networks_lora.py | 2 | ||||
-rw-r--r-- | extensions-builtin/Lora/lora.py | 223 | ||||
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 24 | ||||
-rw-r--r-- | extensions-builtin/Lora/ui_extra_networks_lora.py | 14 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 83 | ||||
-rw-r--r-- | extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js | 128 |
7 files changed, 326 insertions, 168 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index b8cff29b..da19cff1 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -25,22 +25,28 @@ class UpscalerLDSR(Upscaler): yaml_path = os.path.join(self.model_path, "project.yaml") old_model_path = os.path.join(self.model_path, "model.pth") new_model_path = os.path.join(self.model_path, "model.ckpt") - safetensors_model_path = os.path.join(self.model_path, "model.safetensors") + + local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"]) + local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None) + local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None) + local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None) + if os.path.exists(yaml_path): statinfo = os.stat(yaml_path) if statinfo.st_size >= 10485760: print("Removing invalid LDSR YAML file.") os.remove(yaml_path) + if os.path.exists(old_model_path): print("Renaming model from model.pth to model.ckpt") os.rename(old_model_path, new_model_path) - if os.path.exists(safetensors_model_path): - model = safetensors_model_path + + if local_safetensors_path is not None and os.path.exists(local_safetensors_path): + model = local_safetensors_path else: - model = load_file_from_url(url=self.model_url, model_dir=self.model_path, - file_name="model.ckpt", progress=True) - yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, - file_name="project.yaml", progress=True) + model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True) + + yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True) try: return LDSR(model, yaml) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 6be6ef73..45f899fc 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36..d3eb0d3b 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -2,18 +2,34 @@ import glob import os
import re
import torch
+from typing import Union
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
-re_digits = re.compile(r"\d+")
-re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
-re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
+
+
+def convert_diffusers_name_to_compvis(key, is_sd2):
+ def match(match_list, regex_text):
+ regex = re_compiled.get(regex_text)
+ if regex is None:
+ regex = re.compile(regex_text)
+ re_compiled[regex_text] = regex
-def convert_diffusers_name_to_compvis(key):
- def match(match_list, regex):
r = re.match(regex, key)
if not r:
return False
@@ -24,16 +40,33 @@ def convert_diffusers_name_to_compvis(key): m = []
- if match(m, re_unet_down_blocks):
- return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+ if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+ return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
- if match(m, re_unet_mid_blocks):
- return f"diffusion_model_middle_block_1_{m[1]}"
+ if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+ return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
- if match(m, re_unet_up_blocks):
- return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+ if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+ return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+ if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+ if is_sd2:
+ if 'mlp_fc1' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+ elif 'mlp_fc2' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+ else:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
- if match(m, re_text_block):
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
return key
@@ -43,6 +76,23 @@ class LoraOnDisk: def __init__(self, name, filename):
self.name = name
self.filename = filename
+ self.metadata = {}
+
+ _, ext = os.path.splitext(filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading lora {filename}")
+
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+
+ self.metadata = m
+
+ self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
class LoraModule:
@@ -82,15 +132,22 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename)
- keys_failed_to_match = []
+ keys_failed_to_match = {}
+ is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
for key_diffusers, weight in sd.items():
- fullkey = convert_diffusers_name_to_compvis(key_diffusers)
- key, lora_key = fullkey.split(".", 1)
+ key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
+ key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
+ if sd_module is None:
+ m = re_x_proj.match(key)
+ if m:
+ sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
+
if sd_module is None:
- keys_failed_to_match.append(key_diffusers)
+ keys_failed_to_match[key_diffusers] = key
continue
lora_module = lora.modules.get(key, None)
@@ -104,15 +161,21 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif type(sd_module) == torch.nn.MultiheadAttention:
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
+ print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+ continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
with torch.no_grad():
module.weight.copy_(weight)
- module.to(device=devices.device, dtype=devices.dtype)
+ module.to(device=devices.cpu, dtype=devices.dtype)
if lora_key == "lora_up.weight":
lora_module.up = module
@@ -158,28 +221,120 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora)
-def lora_forward(module, input, res):
- if len(loaded_loras) == 0:
- return res
+def lora_calc_updown(lora, module, target):
+ with torch.no_grad():
+ up = module.up.weight.to(target.device, dtype=target.dtype)
+ down = module.down.weight.to(target.device, dtype=target.dtype)
- lora_layer_name = getattr(module, 'lora_layer_name', None)
- for lora in loaded_loras:
- module = lora.modules.get(lora_layer_name, None)
- if module is not None:
- if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
- res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ else:
+ updown = up @ down
+
+ updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return updown
+
+
+def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ """
+ Applies the currently selected set of Loras to the weights of torch layer self.
+ If weights already have this particular set of loras applied, does nothing.
+ If not, restores orginal weights from backup and alters weights according to loras.
+ """
+
+ lora_layer_name = getattr(self, 'lora_layer_name', None)
+ if lora_layer_name is None:
+ return
+
+ current_names = getattr(self, "lora_current_names", ())
+ wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+ weights_backup = getattr(self, "lora_weights_backup", None)
+ if weights_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+ else:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+
+ self.lora_weights_backup = weights_backup
+
+ if current_names != wanted_names:
+ if weights_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
else:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ self.weight.copy_(weights_backup)
+
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is not None and hasattr(self, 'weight'):
+ self.weight += lora_calc_updown(lora, module, self.weight)
+ continue
+
+ module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
+ module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
+ module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
+ module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
+
+ if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+ updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
+ updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
+ updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
+ continue
- return res
+ if module is None:
+ continue
+
+ print(f'failed to calculate lora weights for layer {lora_layer_name}')
+
+ setattr(self, "lora_current_names", wanted_names)
+
+
+def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
+ setattr(self, "lora_current_names", ())
+ setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input):
- return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+ lora_apply_weights(self)
+
+ return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
+
+ return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
def lora_Conv2d_forward(self, input):
- return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+ lora_apply_weights(self)
+
+ return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
+
+ return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
+
+
+def lora_MultiheadAttention_forward(self, *args, **kwargs):
+ lora_apply_weights(self)
+
+ return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
+
+
+def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
+
+ return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras():
@@ -192,7 +347,7 @@ def list_available_loras(): glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
- for filename in sorted(candidates):
+ for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160..3fc38ab9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+ torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+ torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+ torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+ torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
def before_ui():
@@ -20,11 +24,27 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'):
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+ torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+ torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+ torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+ torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
@@ -32,7 +52,5 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
- "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
-
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 22cabcb0..68b11332 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -15,21 +15,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
yield {
"name": name,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": path + ".png",
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
def allowed_directories_for_previews(self):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index e0fbf3a3..c7fd5739 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -5,11 +5,15 @@ import traceback import PIL.Image import numpy as np import torch +from tqdm import tqdm + from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader from scunet_model_arch import SCUNet as net +from modules.shared import opts +from modules import images class UpscalerScuNET(modules.upscaler.Upscaler): @@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - def do_upscale(self, img: PIL.Image, selected_file): + @staticmethod + @torch.no_grad() + def tiled_inference(img, model): + # test the image tile by tile + h, w = img.shape[2:] + tile = opts.SCUNET_tile + tile_overlap = opts.SCUNET_tile_overlap + if tile == 0: + return model(img) + + device = devices.get_device_for('scunet') + assert tile % 8 == 0, "tile size should be a multiple of window_size" + sf = 1 + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) + + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: + for h_idx in h_idx_list: + + for w_idx in w_idx_list: + + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) + output = E.div_(W) + + return output + + def do_upscale(self, img: PIL.Image.Image, selected_file): + torch.cuda.empty_cache() model = self.load_model(selected_file) if model is None: + print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) return img device = devices.get_device_for('scunet') - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) - - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] + tile = opts.SCUNET_tile + h, w = img.height, img.width + np_img = np.array(img) + np_img = np_img[:, :, ::-1] # RGB to BGR + np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW + torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore + + if tile > h or tile > w: + _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) + _img[:, :, :h, :w] = torch_img # pad image + torch_img = _img + + torch_output = self.tiled_inference(torch_img, model).squeeze(0) + torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any + np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() + del torch_img, torch_output torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') + + output = np_output.transpose((1, 2, 0)) # CHW to HWC + output = output[:, :, ::-1] # BGR to RGB + return PIL.Image.fromarray((output * 255).astype(np.uint8)) def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler): model = model.to(device) return model - diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 4a85c8eb..5c7a836a 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -1,110 +1,42 @@ // Stable Diffusion WebUI - Bracket checker -// Version 1.0 -// By Hingashi no Florin/Bwin4L +// By Hingashi no Florin/Bwin4L & @akx // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. -function checkBrackets(evt, textArea, counterElt) { - errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n'; - errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n'; - errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n'; - - openBracketRegExp = /\(/g; - closeBracketRegExp = /\)/g; - - openSquareBracketRegExp = /\[/g; - closeSquareBracketRegExp = /\]/g; - - openCurlyBracketRegExp = /\{/g; - closeCurlyBracketRegExp = /\}/g; - - totalOpenBracketMatches = 0; - totalCloseBracketMatches = 0; - totalOpenSquareBracketMatches = 0; - totalCloseSquareBracketMatches = 0; - totalOpenCurlyBracketMatches = 0; - totalCloseCurlyBracketMatches = 0; - - openBracketMatches = textArea.value.match(openBracketRegExp); - if(openBracketMatches) { - totalOpenBracketMatches = openBracketMatches.length; - } - - closeBracketMatches = textArea.value.match(closeBracketRegExp); - if(closeBracketMatches) { - totalCloseBracketMatches = closeBracketMatches.length; - } - - openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp); - if(openSquareBracketMatches) { - totalOpenSquareBracketMatches = openSquareBracketMatches.length; - } - - closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp); - if(closeSquareBracketMatches) { - totalCloseSquareBracketMatches = closeSquareBracketMatches.length; - } - - openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp); - if(openCurlyBracketMatches) { - totalOpenCurlyBracketMatches = openCurlyBracketMatches.length; - } - - closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp); - if(closeCurlyBracketMatches) { - totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length; - } - - if(totalOpenBracketMatches != totalCloseBracketMatches) { - if(!counterElt.title.includes(errorStringParen)) { - counterElt.title += errorStringParen; +function checkBrackets(textArea, counterElt) { + var counts = {}; + (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => { + counts[bracket] = (counts[bracket] || 0) + 1; + }); + var errors = []; + + function checkPair(open, close, kind) { + if (counts[open] !== counts[close]) { + errors.push( + `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` + ); } - } else { - counterElt.title = counterElt.title.replace(errorStringParen, ''); } - if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) { - if(!counterElt.title.includes(errorStringSquare)) { - counterElt.title += errorStringSquare; - } - } else { - counterElt.title = counterElt.title.replace(errorStringSquare, ''); - } + checkPair('(', ')', 'round brackets'); + checkPair('[', ']', 'square brackets'); + checkPair('{', '}', 'curly brackets'); + counterElt.title = errors.join('\n'); + counterElt.classList.toggle('error', errors.length !== 0); +} - if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) { - if(!counterElt.title.includes(errorStringCurly)) { - counterElt.title += errorStringCurly; - } - } else { - counterElt.title = counterElt.title.replace(errorStringCurly, ''); - } +function setupBracketChecking(id_prompt, id_counter) { + var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); + var counter = gradioApp().getElementById(id_counter) - if(counterElt.title != '') { - counterElt.classList.add('error'); - } else { - counterElt.classList.remove('error'); + if (textarea && counter) { + textarea.addEventListener("input", () => checkBrackets(textarea, counter)); } } -function setupBracketChecking(id_prompt, id_counter){ - var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); - var counter = gradioApp().getElementById(id_counter) - textarea.addEventListener("input", function(evt){ - checkBrackets(evt, textarea, counter) - }); -} - -var shadowRootLoaded = setInterval(function() { - var shadowRoot = document.querySelector('gradio-app').shadowRoot; - if(! shadowRoot) return false; - - var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); - if(shadowTextArea.length < 1) return false; - - clearInterval(shadowRootLoaded); - - setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') - setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') - setupBracketChecking('img2img_prompt', 'imgimg_token_counter') - setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') -}, 1000); +onUiLoaded(function () { + setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); + setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); + setupBracketChecking('img2img_prompt', 'img2img_token_counter'); + setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); +}); |