diff options
author | yfszzx <yfszzx@gmail.com> | 2022-10-17 07:23:32 +0000 |
---|---|---|
committer | yfszzx <yfszzx@gmail.com> | 2022-10-17 07:23:32 +0000 |
commit | 2a3e7ed872dc9b8da27cccc7f78df092f4a2f578 (patch) | |
tree | 5a288fe646eb58d35f4079b734a6d3765a98fd3f /modules | |
parent | 5b1394bead93e5485ced5de10f1c000eea4458c6 (diff) | |
parent | cccc5a20fce4bde9a4299f8790366790735f1d05 (diff) | |
download | stable-diffusion-webui-gfx803-2a3e7ed872dc9b8da27cccc7f78df092f4a2f578.tar.gz stable-diffusion-webui-gfx803-2a3e7ed872dc9b8da27cccc7f78df092f4a2f578.tar.bz2 stable-diffusion-webui-gfx803-2a3e7ed872dc9b8da27cccc7f78df092f4a2f578.zip |
Merge branch 'master' of https://github.com/yfszzx/stable-diffusion-webui-plus
Diffstat (limited to 'modules')
-rw-r--r-- | modules/extras.py | 28 | ||||
-rw-r--r-- | modules/interrogate.py | 8 | ||||
-rw-r--r-- | modules/scripts.py | 3 | ||||
-rw-r--r-- | modules/shared.py | 10 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 1 | ||||
-rw-r--r-- | modules/ui.py | 8 |
6 files changed, 43 insertions, 15 deletions
diff --git a/modules/extras.py b/modules/extras.py index 0819ed37..8dbab240 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -175,11 +175,14 @@ def run_pnginfo(image): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
- def weighted_sum(theta0, theta1, theta2, alpha):
+ def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
- def add_difference(theta0, theta1, theta2, alpha):
- return theta0 + (theta1 - theta2) * alpha
+ def get_difference(theta1, theta2):
+ return theta1 - theta2
+
+ def add_difference(theta0, theta1_2_diff, alpha):
+ return theta0 + (alpha * theta1_2_diff)
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
@@ -198,23 +201,28 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else:
+ teritary_model = None
theta_2 = None
theta_funcs = {
- "Weighted sum": weighted_sum,
- "Add difference": add_difference,
+ "Weighted sum": (None, weighted_sum),
+ "Add difference": (get_difference, add_difference),
}
- theta_func = theta_funcs[interp_method]
+ theta_func1, theta_func2 = theta_funcs[interp_method]
print(f"Merging...")
+ if theta_func1:
+ for key in tqdm.tqdm(theta_1.keys()):
+ if 'model' in key:
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+ theta_1[key] = theta_func1(theta_1[key], t2)
+ del theta_2, teritary_model
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- t2 = (theta_2 or {}).get(key)
- if t2 is None:
- t2 = torch.zeros_like(theta_0[key])
- theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
+ theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
if save_as_half:
theta_0[key] = theta_0[key].half()
diff --git a/modules/interrogate.py b/modules/interrogate.py index 9263d65a..64b91eb4 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -123,7 +123,7 @@ class InterrogateModels: return caption[0]
- def interrogate(self, pil_image, include_ranks=False):
+ def interrogate(self, pil_image):
res = None
try:
@@ -156,10 +156,10 @@ class InterrogateModels: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- if include_ranks:
- res += ", " + match
+ if shared.opts.interrogate_return_ranks:
+ res += f", ({match}:{score/100:.3f})"
else:
- res += f", ({match}:{score})"
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/scripts.py b/modules/scripts.py index 45230f9a..ac66d448 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -58,6 +58,9 @@ def load_scripts(basedir): for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename)
+ if os.path.splitext(path)[1].lower() != '.py':
+ continue
+
if not os.path.isfile(path):
continue
diff --git a/modules/shared.py b/modules/shared.py index 72513f86..c2ea4186 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -77,6 +77,16 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args()
+restricted_opts = [
+ "samples_filename_pattern",
+ "outdir_samples",
+ "outdir_txt2img_samples",
+ "outdir_img2img_samples",
+ "outdir_extras_samples",
+ "outdir_grids",
+ "outdir_txt2img_grids",
+ "outdir_save",
+]
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7ec75018..3be69562 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -137,6 +137,7 @@ class EmbeddingDatabase: continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
diff --git a/modules/ui.py b/modules/ui.py index 7b0d5a92..43dc88fc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -25,7 +25,7 @@ import gradio.routes from modules import sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
@@ -1430,6 +1430,9 @@ Requested path was: {f} if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ continue
+
oldval = opts.data.get(key, None)
opts.data[key] = value
@@ -1447,6 +1450,9 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ return gr.update(value=oldval), opts.dumpjson()
+
oldval = opts.data.get(key, None)
opts.data[key] = value
|