aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/lora.py5
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py7
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py8
3 files changed, 17 insertions, 3 deletions
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 60b9eb64..544b228d 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -3,7 +3,7 @@ import torch
import lora
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
@@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
for w_idx in w_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)