aboutsummaryrefslogtreecommitdiffstats
path: root/modules/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/extras.py')
-rw-r--r--modules/extras.py73
1 files changed, 71 insertions, 2 deletions
diff --git a/modules/extras.py b/modules/extras.py
index 382ffa7d..6a0d5cb0 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -3,13 +3,17 @@ import os
import numpy as np
from PIL import Image
-from modules import processing, shared, images, devices
+import torch
+import tqdm
+
+from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
import piexif.helper
+import gradio as gr
cached_images = {}
@@ -36,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs = []
for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -70,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c
return c
@@ -135,3 +141,66 @@ def run_pnginfo(image):
info = f"<div><p>{message}<p></div>"
return '', geninfo, info
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
+ # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
+ def weighted_sum(theta0, theta1, alpha):
+ return ((1 - alpha) * theta0) + (alpha * theta1)
+
+ # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ def sigmoid(theta0, theta1, alpha):
+ alpha = alpha * alpha * (3 - (2 * alpha))
+ return theta0 + ((theta1 - theta0) * alpha)
+
+ # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ def inv_sigmoid(theta0, theta1, alpha):
+ import math
+ alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+ return theta0 + ((theta1 - theta0) * alpha)
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+
+ print(f"Loading {primary_model_info.filename}...")
+ primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+
+ print(f"Loading {secondary_model_info.filename}...")
+ secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+
+ theta_0 = primary_model['state_dict']
+ theta_1 = secondary_model['state_dict']
+
+ theta_funcs = {
+ "Weighted Sum": weighted_sum,
+ "Sigmoid": sigmoid,
+ "Inverse Sigmoid": inv_sigmoid,
+ }
+ theta_func = theta_funcs[interp_method]
+
+ print(f"Merging...")
+ for key in tqdm.tqdm(theta_0.keys()):
+ if 'model' in key and key in theta_1:
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+ if save_as_half:
+ theta_0[key] = theta_0[key].half()
+
+ for key in theta_1.keys():
+ if 'model' in key and key not in theta_0:
+ theta_0[key] = theta_1[key]
+ if save_as_half:
+ theta_0[key] = theta_0[key].half()
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+
+ filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+ filename = filename if custom_name == '' else (custom_name + '.ckpt')
+ output_modelname = os.path.join(ckpt_dir, filename)
+
+ print(f"Saving to {output_modelname}...")
+ torch.save(primary_model, output_modelname)
+
+ sd_models.list_models()
+
+ print(f"Checkpoint saved.")
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]