From 0fc7dc1c04a046d95588651ffc4e71a7d40378d3 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:01:13 -0600 Subject: implementing script metadata and DAG sorting mechanism --- modules/scripts.py | 141 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 126 insertions(+), 15 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 5c6e0226..e92a34a0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -2,6 +2,7 @@ import os import re import sys import inspect +from graphlib import TopologicalSorter, CycleError from collections import namedtuple from dataclasses import dataclass @@ -314,15 +315,131 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi def list_scripts(scriptdirname, extension, *, include_extensions=True): scripts_list = [] - - basedir = os.path.join(paths.script_path, scriptdirname) - if os.path.exists(basedir): - for filename in sorted(os.listdir(basedir)): - scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + script_dependency_map = {} + + # build script dependency map + + root_script_basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(root_script_basedir): + for filename in sorted(os.listdir(root_script_basedir)): + script_dependency_map[filename] = { + "extension": None, + "extension_dirname": None, + "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)), + "requires": [], + "load_before": [], + "load_after": [], + } if include_extensions: for ext in extensions.active(): - scripts_list += ext.list_files(scriptdirname, extension) + extension_scripts_list = ext.list_files(scriptdirname, extension) + for extension_script in extension_scripts_list: + # this is built on the assumption that script name is unique. + # I think bad thing is gonna happen if name collide in the current implementation anyway, but we + # will need to refactor here if this assumption is broken later on. + if extension_script.filename in script_dependency_map: + errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions " + f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". " + f"The current loading file will be discarded.", exc_info=False) + continue + + relative_path = scriptdirname + "/" + extension_script.filename + + requires = None + load_before = None + load_after = None + + if ext.metadata is not None: + requires = ext.metadata.get(relative_path, "Requires", fallback=None) + load_before = ext.metadata.get(relative_path, "Before", fallback=None) + load_after = ext.metadata.get(relative_path, "After", fallback=None) + + requires = [x.strip() for x in requires.split(',')] if requires else [] + load_after = [x.strip() for x in load_after.split(',')] if load_after else [] + load_before = [x.strip() for x in load_before.split(',')] if load_before else [] + + script_dependency_map[extension_script.filename] = { + "extension": ext.canonical_name, + "extension_dirname": ext.name, + "script_file": extension_script, + "requires": requires, + "load_before": load_before, + "load_after": load_after, + } + + # resolve dependencies + + loaded_extensions = set() + for _, script_data in script_dependency_map.items(): + if script_data['extension'] is not None: + loaded_extensions.add(script_data['extension']) + + for script_filename, script_data in script_dependency_map.items(): + # load before requires inverse dependency + # in this case, append the script name into the load_after list of the specified script + for load_before_script in script_data['load_before']: + if load_before_script.startswith('ext:'): + # if this requires an extension to be loaded before + required_extension = load_before_script[4:] + for _, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == required_extension: + script_data2['load_after'].append(script_filename) + break + else: + # if this requires an individual script to be loaded before + if load_before_script in script_dependency_map: + script_dependency_map[load_before_script]['load_after'].append(script_filename) + + # resolve extension name in load_after lists + for load_after_script in script_data['load_after']: + if load_after_script.startswith('ext:'): + # if this requires an extension to be loaded after + required_extension = load_after_script[4:] + for script_file_name2, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == required_extension: + script_data['load_after'].append(script_file_name2) + + # remove all extension names in load_after lists + script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')] + + # build the DAG + sorter = TopologicalSorter() + for script_filename, script_data in script_dependency_map.items(): + requirement_met = True + for required_script in script_data['requires']: + if required_script.startswith('ext:'): + # if this requires an extension to be installed + required_extension = required_script[4:] + if required_extension not in loaded_extensions: + errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to " + f"be loaded, but it is not. Skipping.", + exc_info=False) + requirement_met = False + break + else: + # if this requires an individual script to be loaded + if required_script not in script_dependency_map: + errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to " + f"be loaded, but it is not. Skipping.", + exc_info=False) + requirement_met = False + break + if not requirement_met: + continue + + sorter.add(script_filename, *script_data['load_after']) + + # sort the scripts + try: + ordered_script = sorter.static_order() + except CycleError: + errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True) + ordered_script = script_dependency_map.keys() + + for script_filename in ordered_script: + script_data = script_dependency_map[script_filename] + scripts_list.append(script_data['script_file']) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -365,15 +482,9 @@ def load_scripts(): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) - def orderby(basedir): - # 1st webui, 2nd extensions-builtin, 3rd extensions - priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} - for key in priority: - if basedir.startswith(key): - return priority[key] - return 9999 - - for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]): + # here the scripts_list is already ordered + # processing_script is not considered though + for scriptfile in scripts_list: try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path -- cgit v1.2.3 From 0d1924c48be3d02650e87b12a4f53165a8b4a599 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:03:55 -0600 Subject: populate loaded_extensions from extension list instead --- modules/scripts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index e92a34a0..7cdf288d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -371,9 +371,8 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): # resolve dependencies loaded_extensions = set() - for _, script_data in script_dependency_map.items(): - if script_data['extension'] is not None: - loaded_extensions.add(script_data['extension']) + for ext in extensions.active(): + loaded_extensions.add(ext.canonical_name) for script_filename, script_data in script_dependency_map.items(): # load before requires inverse dependency -- cgit v1.2.3 From 7af576e745c79a9539e40bc158e695192ae79f25 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 10:46:47 -0600 Subject: remove the assumption of same name --- modules/scripts.py | 81 ++++++++++++++++++++---------------------------------- 1 file changed, 30 insertions(+), 51 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 7cdf288d..7ad22245 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -335,15 +335,9 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): extension_scripts_list = ext.list_files(scriptdirname, extension) for extension_script in extension_scripts_list: - # this is built on the assumption that script name is unique. - # I think bad thing is gonna happen if name collide in the current implementation anyway, but we - # will need to refactor here if this assumption is broken later on. - if extension_script.filename in script_dependency_map: - errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions " - f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". " - f"The current loading file will be discarded.", exc_info=False) - continue - + script_canonical_name = ext.canonical_name + "/" + extension_script.filename + if ext.is_builtin: + script_canonical_name = "builtin/" + script_canonical_name relative_path = scriptdirname + "/" + extension_script.filename requires = None @@ -359,7 +353,7 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): load_after = [x.strip() for x in load_after.split(',')] if load_after else [] load_before = [x.strip() for x in load_before.split(',')] if load_before else [] - script_dependency_map[extension_script.filename] = { + script_dependency_map[script_canonical_name] = { "extension": ext.canonical_name, "extension_dirname": ext.name, "script_file": extension_script, @@ -374,60 +368,45 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): loaded_extensions.add(ext.canonical_name) - for script_filename, script_data in script_dependency_map.items(): + for script_canonical_name, script_data in script_dependency_map.items(): # load before requires inverse dependency # in this case, append the script name into the load_after list of the specified script for load_before_script in script_data['load_before']: - if load_before_script.startswith('ext:'): - # if this requires an extension to be loaded before - required_extension = load_before_script[4:] + # if this requires an individual script to be loaded before + if load_before_script in script_dependency_map: + script_dependency_map[load_before_script]['load_after'].append(script_canonical_name) + elif load_before_script in loaded_extensions: for _, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == required_extension: - script_data2['load_after'].append(script_filename) + if script_data2['extension'] == load_before_script: + script_data2['load_after'].append(script_canonical_name) break - else: - # if this requires an individual script to be loaded before - if load_before_script in script_dependency_map: - script_dependency_map[load_before_script]['load_after'].append(script_filename) # resolve extension name in load_after lists - for load_after_script in script_data['load_after']: - if load_after_script.startswith('ext:'): - # if this requires an extension to be loaded after - required_extension = load_after_script[4:] - for script_file_name2, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == required_extension: - script_data['load_after'].append(script_file_name2) - - # remove all extension names in load_after lists - script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')] + for load_after_script in list(script_data['load_after']): + if load_after_script not in script_dependency_map and load_after_script in loaded_extensions: + script_data['load_after'].remove(load_after_script) + for script_canonical_name2, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == load_after_script: + script_data['load_after'].remove(script_canonical_name2) + break # build the DAG sorter = TopologicalSorter() - for script_filename, script_data in script_dependency_map.items(): + for script_canonical_name, script_data in script_dependency_map.items(): requirement_met = True for required_script in script_data['requires']: - if required_script.startswith('ext:'): - # if this requires an extension to be installed - required_extension = required_script[4:] - if required_extension not in loaded_extensions: - errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break - else: - # if this requires an individual script to be loaded - if required_script not in script_dependency_map: - errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break + # if this requires an individual script to be loaded + if required_script not in script_dependency_map and required_script not in loaded_extensions: + errors.report(f"Script \"{script_canonical_name}\" " + f"requires \"{required_script}\" to " + f"be loaded, but it is not. Skipping.", + exc_info=False) + requirement_met = False + break if not requirement_met: continue - sorter.add(script_filename, *script_data['load_after']) + sorter.add(script_canonical_name, *script_data['load_after']) # sort the scripts try: @@ -436,8 +415,8 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True) ordered_script = script_dependency_map.keys() - for script_filename in ordered_script: - script_data = script_dependency_map[script_filename] + for script_canonical_name in ordered_script: + script_data = script_dependency_map[script_canonical_name] scripts_list.append(script_data['script_file']) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] -- cgit v1.2.3 From 520e52f846892cc2b207b738b4180fa863c7b38f Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 10:58:26 -0600 Subject: allow comma and whitespace as separator --- modules/extensions.py | 9 ++++++--- modules/scripts.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/extensions.py b/modules/extensions.py index 7583a3b0..795af996 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -2,6 +2,7 @@ import configparser import functools import os import threading +import re from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo @@ -48,7 +49,8 @@ class Extension: config.read(os.path.join(self.path, "sd_webui_metadata.ini")) return config except Exception: - errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True) + errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", + exc_info=True) return None def to_dict(self): @@ -70,6 +72,7 @@ class Extension: self.do_read_info_from_repo() return self.to_dict() + try: d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) self.from_dict(d) @@ -194,8 +197,8 @@ def list_extensions(): f"The current loading extension will be discarded.", exc_info=False) continue - # we want to wash the data to lowercase and remove whitespaces just in case - requires = [x.strip() for x in requires.lower().split(',')] if requires else [] + # both "," and " " are accepted as separator + requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] extension_dependency_map[canonical_name] = { "dirname": extension_dirname, diff --git a/modules/scripts.py b/modules/scripts.py index 7ad22245..5dd0555d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -349,9 +349,9 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): load_before = ext.metadata.get(relative_path, "Before", fallback=None) load_after = ext.metadata.get(relative_path, "After", fallback=None) - requires = [x.strip() for x in requires.split(',')] if requires else [] - load_after = [x.strip() for x in load_after.split(',')] if load_after else [] - load_before = [x.strip() for x in load_before.split(',')] if load_before else [] + requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] + load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else [] + load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else [] script_dependency_map[script_canonical_name] = { "extension": ext.canonical_name, -- cgit v1.2.3 From 3bb32befe9523a6acefbab7fe099f91660f41ea9 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 11:58:19 -0600 Subject: bug fix --- modules/scripts.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 5dd0555d..b1f4504a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -322,6 +322,9 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): root_script_basedir = os.path.join(paths.script_path, scriptdirname) if os.path.exists(root_script_basedir): for filename in sorted(os.listdir(root_script_basedir)): + if not os.path.isfile(os.path.join(root_script_basedir, filename)): + continue + script_dependency_map[filename] = { "extension": None, "extension_dirname": None, @@ -335,19 +338,27 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): extension_scripts_list = ext.list_files(scriptdirname, extension) for extension_script in extension_scripts_list: + if not os.path.isfile(extension_script.path): + continue + script_canonical_name = ext.canonical_name + "/" + extension_script.filename if ext.is_builtin: script_canonical_name = "builtin/" + script_canonical_name relative_path = scriptdirname + "/" + extension_script.filename - requires = None - load_before = None - load_after = None + requires = '' + load_before = '' + load_after = '' if ext.metadata is not None: - requires = ext.metadata.get(relative_path, "Requires", fallback=None) - load_before = ext.metadata.get(relative_path, "Before", fallback=None) - load_after = ext.metadata.get(relative_path, "After", fallback=None) + requires = ext.metadata.get(relative_path, "Requires", fallback='') + load_before = ext.metadata.get(relative_path, "Before", fallback='') + load_after = ext.metadata.get(relative_path, "After", fallback='') + + # propagate directory level metadata + requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='') + load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='') + load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='') requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else [] @@ -387,7 +398,7 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): script_data['load_after'].remove(load_after_script) for script_canonical_name2, script_data2 in script_dependency_map.items(): if script_data2['extension'] == load_after_script: - script_data['load_after'].remove(script_canonical_name2) + script_data['load_after'].append(script_canonical_name2) break # build the DAG -- cgit v1.2.3 From 9b471436b2226458a767077707ea102e331b5d78 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 20 Nov 2023 14:47:09 +0300 Subject: rework extensions metadata: use custom sorter that doesn't mess the order as much and ignores cyclic errors, use classes with named fields instead of dictionaries, eliminate some duplicated code --- modules/extensions.py | 132 +++++++++++++++++++++------------------ modules/scripts.py | 169 +++++++++++++++++++++++--------------------------- 2 files changed, 148 insertions(+), 153 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/extensions.py b/modules/extensions.py index f3988d02..1899cd52 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import configparser -import functools import os import threading import re @@ -8,7 +9,6 @@ from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 -extensions = [] os.makedirs(extensions_dir, exist_ok=True) @@ -22,13 +22,56 @@ def active(): return [x for x in extensions if x.enabled] +class ExtensionMetadata: + filename = "metadata.ini" + config: configparser.ConfigParser + canonical_name: str + requires: list + + def __init__(self, path, canonical_name): + self.config = configparser.ConfigParser() + + filepath = os.path.join(path, self.filename) + if os.path.isfile(filepath): + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + + self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) + self.canonical_name = canonical_name.lower().strip() + + self.requires = self.get_script_requirements("Requires", "Extension") + + def get_script_requirements(self, field, section, extra_section=None): + """reads a list of requirements from the config; field is the name of the field in the ini file, + like Requires or Before, and section is the name of the [section] in the ini file; additionally, + reads more requirements from [extra_section] if specified.""" + + x = self.config.get(section, field, fallback='') + + if extra_section: + x = x + ', ' + self.config.get(extra_section, field, fallback='') + + return self.parse_list(x.lower()) + + def parse_list(self, text): + """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" + + if not text: + return [] + + # both "," and " " are accepted as separator + return [x for x in re.split(r"[,\s]+", text.strip()) if x] + + class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] + metadata: ExtensionMetadata - def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None): + def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None): self.name = name - self.canonical_name = canonical_name or name.lower() self.path = path self.enabled = enabled self.status = '' @@ -40,18 +83,8 @@ class Extension: self.branch = None self.remote = None self.have_info_from_repo = False - - @functools.cached_property - def metadata(self): - if os.path.isfile(os.path.join(self.path, "metadata.ini")): - try: - config = configparser.ConfigParser() - config.read(os.path.join(self.path, "metadata.ini")) - return config - except Exception: - errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.", - exc_info=True) - return None + self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) + self.canonical_name = metadata.canonical_name def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -162,7 +195,7 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_dependency_map = {} + loaded_extensions = {} # scan through extensions directory and load metadata for dirname in [extensions_builtin_dir, extensions_dir]: @@ -175,55 +208,30 @@ def list_extensions(): continue canonical_name = extension_dirname - requires = None + metadata = ExtensionMetadata(path, canonical_name) - if os.path.isfile(os.path.join(path, "metadata.ini")): - try: - config = configparser.ConfigParser() - config.read(os.path.join(path, "metadata.ini")) - canonical_name = config.get("Extension", "Name", fallback=canonical_name) - requires = config.get("Extension", "Requires", fallback=None) - except Exception: - errors.report(f"Error reading metadata.ini for extension {extension_dirname}. " - f"Will load regardless.", exc_info=True) + # check for duplicated canonical names + already_loaded_extension = loaded_extensions.get(metadata.canonical_name) + if already_loaded_extension is not None: + errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False) + continue - canonical_name = canonical_name.lower().strip() + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) + extensions.append(extension) + loaded_extensions[canonical_name] = extension - # check for duplicated canonical names - if canonical_name in extension_dependency_map: - errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions " - f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". " - f"The current loading extension will be discarded.", exc_info=False) + # check for requirements + for extension in extensions: + for req in extension.metadata.requires: + required_extension = loaded_extensions.get(req) + if required_extension is None: + errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) continue - # both "," and " " are accepted as separator - requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] + if not extension.enabled: + errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) + continue - extension_dependency_map[canonical_name] = { - "dirname": extension_dirname, - "path": path, - "requires": requires, - } - # check for requirements - for (_, extension_data) in extension_dependency_map.items(): - dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires'] - requirement_met = True - for req in requires: - if req not in extension_dependency_map: - errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. " - f"The current loading extension will be discarded.", exc_info=False) - requirement_met = False - break - dep_dirname = extension_dependency_map[req]['dirname'] - if dep_dirname in shared.opts.disabled_extensions: - errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. " - f"The current loading extension will be discarded.", exc_info=False) - requirement_met = False - break - - is_builtin = dirname == extensions_builtin_dir - extension = Extension(name=dirname, path=path, - enabled=dirname not in shared.opts.disabled_extensions and requirement_met, - is_builtin=is_builtin) - extensions.append(extension) +extensions: list[Extension] = [] diff --git a/modules/scripts.py b/modules/scripts.py index b1f4504a..b0689a23 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -2,7 +2,6 @@ import os import re import sys import inspect -from graphlib import TopologicalSorter, CycleError from collections import namedtuple from dataclasses import dataclass @@ -312,27 +311,57 @@ scripts_data = [] postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. + Ignores errors relating to missing dependeencies or circular dependencies + """ + + visited = {} + result = [] + + def inner(name): + visited[name] = True + + for dep in dependencies.get(name, []): + if dep in dependencies and dep not in visited: + inner(dep) + + result.append(name) + + for depname in dependencies: + if depname not in visited: + inner(depname) + + return result + + +@dataclass +class ScriptWithDependencies: + script_canonical_name: str + file: ScriptFile + requires: list + load_before: list + load_after: list + def list_scripts(scriptdirname, extension, *, include_extensions=True): - scripts_list = [] - script_dependency_map = {} + scripts = {} - # build script dependency map + loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()} + loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()} + # build script dependency map root_script_basedir = os.path.join(paths.script_path, scriptdirname) if os.path.exists(root_script_basedir): for filename in sorted(os.listdir(root_script_basedir)): if not os.path.isfile(os.path.join(root_script_basedir, filename)): continue - script_dependency_map[filename] = { - "extension": None, - "extension_dirname": None, - "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)), - "requires": [], - "load_before": [], - "load_after": [], - } + if os.path.splitext(filename)[1].lower() != extension: + continue + + script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)) + scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], []) if include_extensions: for ext in extensions.active(): @@ -341,96 +370,54 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): if not os.path.isfile(extension_script.path): continue - script_canonical_name = ext.canonical_name + "/" + extension_script.filename - if ext.is_builtin: - script_canonical_name = "builtin/" + script_canonical_name + script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename relative_path = scriptdirname + "/" + extension_script.filename - requires = '' - load_before = '' - load_after = '' - - if ext.metadata is not None: - requires = ext.metadata.get(relative_path, "Requires", fallback='') - load_before = ext.metadata.get(relative_path, "Before", fallback='') - load_after = ext.metadata.get(relative_path, "After", fallback='') - - # propagate directory level metadata - requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='') - load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='') - load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='') - - requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] - load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else [] - load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else [] - - script_dependency_map[script_canonical_name] = { - "extension": ext.canonical_name, - "extension_dirname": ext.name, - "script_file": extension_script, - "requires": requires, - "load_before": load_before, - "load_after": load_after, - } + script = ScriptWithDependencies( + script_canonical_name=script_canonical_name, + file=extension_script, + requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname), + load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname), + load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname), + ) - # resolve dependencies + scripts[script_canonical_name] = script + loaded_extensions_scripts[ext.canonical_name].append(script) - loaded_extensions = set() - for ext in extensions.active(): - loaded_extensions.add(ext.canonical_name) - - for script_canonical_name, script_data in script_dependency_map.items(): + for script_canonical_name, script in scripts.items(): # load before requires inverse dependency # in this case, append the script name into the load_after list of the specified script - for load_before_script in script_data['load_before']: + for load_before in script.load_before: # if this requires an individual script to be loaded before - if load_before_script in script_dependency_map: - script_dependency_map[load_before_script]['load_after'].append(script_canonical_name) - elif load_before_script in loaded_extensions: - for _, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == load_before_script: - script_data2['load_after'].append(script_canonical_name) - break - - # resolve extension name in load_after lists - for load_after_script in list(script_data['load_after']): - if load_after_script not in script_dependency_map and load_after_script in loaded_extensions: - script_data['load_after'].remove(load_after_script) - for script_canonical_name2, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == load_after_script: - script_data['load_after'].append(script_canonical_name2) - break - - # build the DAG - sorter = TopologicalSorter() - for script_canonical_name, script_data in script_dependency_map.items(): - requirement_met = True - for required_script in script_data['requires']: - # if this requires an individual script to be loaded - if required_script not in script_dependency_map and required_script not in loaded_extensions: - errors.report(f"Script \"{script_canonical_name}\" " - f"requires \"{required_script}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break - if not requirement_met: - continue + other_script = scripts.get(load_before) + if other_script: + other_script.load_after.append(script_canonical_name) - sorter.add(script_canonical_name, *script_data['load_after']) + # if this requires an extension + other_extension_scripts = loaded_extensions_scripts.get(load_before) + if other_extension_scripts: + for other_script in other_extension_scripts: + other_script.load_after.append(script_canonical_name) - # sort the scripts - try: - ordered_script = sorter.static_order() - except CycleError: - errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True) - ordered_script = script_dependency_map.keys() + # if After mentions an extension, remove it and instead add all of its scripts + for load_after in list(script.load_after): + if load_after not in scripts and load_after in loaded_extensions_scripts: + script.load_after.remove(load_after) + + for other_script in loaded_extensions_scripts.get(load_after, []): + script.load_after.append(other_script.script_canonical_name) + + dependencies = {} + + for script_canonical_name, script in scripts.items(): + for required_script in script.requires: + if required_script not in scripts and required_script not in loaded_extensions: + errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False) - for script_canonical_name in ordered_script: - script_data = script_dependency_map[script_canonical_name] - scripts_list.append(script_data['script_file']) + dependencies[script_canonical_name] = script.load_after - scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + ordered_scripts = topological_sort(dependencies) + scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts] return scripts_list -- cgit v1.2.3 From 8a6e4bda21dddef3ab2e70a05d71b587b6c8b04b Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:00:17 +0900 Subject: catch uncaught exception with ui creation scripts prevent total webui crash --- modules/scripts.py | 54 +++++++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 25 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index b0689a23..961d032c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -570,40 +570,44 @@ class ScriptRunner: if controls is None: return - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] + try: + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] - for control in controls: - control.custom_script_source = os.path.basename(script.filename) + for control in controls: + control.custom_script_source = os.path.basename(script.filename) - arg_info = api_models.ScriptArg(label=control.label or "") + arg_info = api_models.ScriptArg(label=control.label or "") - for field in ("value", "minimum", "maximum", "step"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) + for field in ("value", "minimum", "maximum", "step"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) - choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string - if choices is not None: - arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] + choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string + if choices is not None: + arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] - api_args.append(arg_info) + api_args.append(arg_info) - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names - self.inputs += controls - script.args_to = len(self.inputs) + self.inputs += controls + script.args_to = len(self.inputs) + + except Exception: + errors.report(f"Error creating UI for {script.name}: ", exc_info=True) def setup_ui_for_section(self, section, scriptlist=None): if scriptlist is None: -- cgit v1.2.3 From ef6b8123dc57e4e4bd5e08d9f3e3dbdfdf6b4c4a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 09:57:39 +0300 Subject: put code that can cause an exception into its own function for #14120 --- modules/scripts.py | 62 +++++++++++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 29 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 961d032c..7f9454eb 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -560,54 +560,58 @@ class ScriptRunner: on_after.clear() def create_script_ui(self, script): - import modules.api.models as api_models script.args_from = len(self.inputs) script.args_to = len(self.inputs) + try: + self.create_script_ui_inner(script) + except Exception: + errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + + def create_script_ui_inner(self, script): + import modules.api.models as api_models + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) if controls is None: return - try: - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - for control in controls: - control.custom_script_source = os.path.basename(script.filename) + api_args = [] - arg_info = api_models.ScriptArg(label=control.label or "") + for control in controls: + control.custom_script_source = os.path.basename(script.filename) - for field in ("value", "minimum", "maximum", "step"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) + arg_info = api_models.ScriptArg(label=control.label or "") - choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string - if choices is not None: - arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] + for field in ("value", "minimum", "maximum", "step"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) - api_args.append(arg_info) + choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string + if choices is not None: + arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) + api_args.append(arg_info) - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields - self.inputs += controls - script.args_to = len(self.inputs) + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names - except Exception: - errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + self.inputs += controls + script.args_to = len(self.inputs) def setup_ui_for_section(self, section, scriptlist=None): if scriptlist is None: -- cgit v1.2.3 From ac4578912395627731f2cd8529f87a95df1f7644 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 21:16:27 -0700 Subject: Removed soft inpainting, added hooks for softpainting to work instead. --- modules/processing.py | 94 +++++++++++++++---------------------- modules/scripts.py | 70 +++++++++++++++++++++++++++ modules/sd_samplers_cfg_denoiser.py | 23 ++++----- 3 files changed, 118 insertions(+), 69 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/processing.py b/modules/processing.py index 7d46949f..5a1a90af 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,7 +30,6 @@ import modules.sd_models as sd_models import modules.sd_vae as sd_vae from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion -import modules.soft_inpainting as si from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc): return image -def apply_overlay(image, paste_loc, index, overlays): - if overlays is None or index >= len(overlays): +def apply_overlay(image, paste_loc, overlay): + if overlay is None: return image - overlay = overlays[index] - if paste_loc is not None: image = uncrop(image, (overlay.width, overlay.height), paste_loc) @@ -150,7 +147,6 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None - masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = None @@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = pp.samples + if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim - # todo: generate adaptive masks based on pixel differences. - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_masks(soft_inpainting=p.soft_inpainting, - nmask=p.nmask, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -955,9 +937,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = p.mask_for_overlay + overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) @@ -968,9 +959,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -981,13 +972,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: - mask_for_overlay = p.mask_for_overlay - elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]: - mask_for_overlay = p.masks_for_overlay[i] - else: - mask_for_overlay = None - if mask_for_overlay is not None: if opts.return_mask or opts.save_mask: image_mask = mask_for_overlay.convert('RGB') @@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - soft_inpainting: si.SoftInpaintingParameters = si.default + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 @@ -1447,7 +1431,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1465,7 +1449,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask if self.soft_inpainting is None else None + self.mask_for_overlay = image_mask mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1476,13 +1460,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) - if self.soft_inpainting is None: - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) - - self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1504,15 +1485,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - if self.soft_inpainting is not None: - # We apply the masks AFTER to adjust mask based on changed content. - self.overlay_images.append(image.convert('RGBA')) - self.masks_for_overlay.append(image_mask) - else: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1565,7 +1541,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - if self.soft_inpainting is None: + if self.mask_round: latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) @@ -1578,7 +1554,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1589,8 +1565,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None and self.soft_inpainting is None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 7f9454eb..92a07c56 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_samples = blended_samples + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ class Script: pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original content is blended with the inpainted content. + This is called at every step in the denoising process and once at the end. + If is_final_blend is true, this is called for the final blending stage. + Otherwise, denoiser and sigma are defined and may be used to inform the procedure. + """ + + pass + + def post_sample(self, p, ps: PostSampleArgs, *args): + """ + Called after the samples have been generated, + but before they have been decoded by the VAE, if applicable. + Check getattr(samples, 'already_decoded', False) to test if the images are decoded. + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -213,6 +252,13 @@ class Script: pass + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -767,6 +813,22 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def post_sample(self, p, ps: PostSampleArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.post_sample(p, ps, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + + def on_mask_blend(self, p, mba: MaskBlendArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.on_mask_blend(p, mba, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: @@ -775,6 +837,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_maskoverlay(p, ppmo, *script_args) + except Exception: + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f13e8dcc..eb9d5daf 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -109,19 +109,16 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" # If we use masks, blending between the denoised and original latent images occurs here. - def apply_blend(latent): - if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): - return self.p.denoiser_masked_blend_function( - self, - # Using an argument dictionary so that arguments can be added without breaking extensions. - args= - { - "denoiser": self, - "current_latent": latent, - "sigma": sigma - }) - else: - return self.init_latent * self.mask + self.nmask * latent + def apply_blend(current_latent): + blended_latent = current_latent * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + return blended_latent # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: -- cgit v1.2.3 From 2abc417834d752e43a283f8603bfddfb1c80b30f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 22:25:53 -0700 Subject: Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to. --- modules/processing.py | 23 ++- modules/scripts.py | 4 +- modules/soft_inpainting.py | 308 ---------------------------------- scripts/soft_inpainting.py | 401 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 413 insertions(+), 323 deletions(-) delete mode 100644 modules/soft_inpainting.py create mode 100644 scripts/soft_inpainting.py (limited to 'modules/scripts.py') diff --git a/modules/processing.py b/modules/processing.py index 5a1a90af..f8d85bdf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: ps = scripts.PostSampleArgs(samples_ddim) p.scripts.post_sample(p, ps) - samples_ddim = pp.samples + samples_ddim = ps.samples if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) p.scripts.postprocess_maskoverlay(p, ppmo) - mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: @@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) image = apply_overlay(image, p.paste_to, overlay_image) @@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size - if self.masks_for_overlay is not None: - self.masks_for_overlay = self.masks_for_overlay * self.batch_size - if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size @@ -1565,14 +1561,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - blended_samples = samples * self.nmask + self.init_latent * self.mask + if self.mask is not None: + blended_samples = samples * self.nmask + self.init_latent * self.mask - if self.scripts is not None: - mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) - self.scripts.on_mask_blend(self, mba) - blended_samples = mba.blended_latent + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent - samples = blended_samples + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 92a07c56..b6fcf96e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() class MaskBlendArgs: - def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): self.current_latent = current_latent self.nmask = nmask self.init_latent = init_latent self.mask = mask - self.blended_samples = blended_samples + self.blended_latent = blended_latent self.denoiser = denoiser self.is_final_blend = denoiser is None diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py deleted file mode 100644 index b36ac8fa..00000000 --- a/modules/soft_inpainting.py +++ /dev/null @@ -1,308 +0,0 @@ -class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): - self.mask_blend_power = mask_blend_power - self.mask_blend_scale = mask_blend_scale - self.inpaint_detail_preservation = inpaint_detail_preservation - - def add_generation_params(self, dest): - dest[enabled_gen_param_label] = True - dest[gen_param_labels.mask_blend_power] = self.mask_blend_power - dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale - dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation - - -# ------------------- Methods ------------------- - - -def latent_blend(soft_inpainting, a, b, t): - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - import torch - - # NOTE: We use inplace operations wherever possible. - - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) - - one_minus_t2 = 1 - t2 - one_minus_t3 = 1 - t3 - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t2 - b_scaled = b * t2 - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled, t2, one_minus_t2 - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, t3, one_minus_t3 - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - del result_type - - return image_interp_scaled - - -def get_modified_nmask(soft_inpainting, nmask, sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - import torch - # todo: Why is sigma 2D? Both values are the same. - return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) - - -def apply_adaptive_masks( - latent_orig, - latent_processed, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. * converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc. uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - paste_to) - - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - -def apply_masks( - soft_inpainting, - nmask, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) - converted_mask = 255. * converted_mask - converted_mask = converted_mask.cpu().numpy().astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (width, height), - paste_to) - - for i, overlay_image in enumerate(overlay_images): - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - -# ------------------- Constants ------------------- - - -default = SoftInpaintingSettings(1, 0.5, 4) - -enabled_ui_label = "Soft inpainting" -enabled_gen_param_label = "Soft inpainting enabled" -enabled_el_id = "soft_inpainting_enabled" - -ui_labels = SoftInpaintingSettings( - "Schedule bias", - "Preservation strength", - "Transition contrast boost") - -ui_info = SoftInpaintingSettings( - "Shifts when preservation of original content occurs during denoising.", - "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") - -gen_param_labels = SoftInpaintingSettings( - "Soft inpainting schedule bias", - "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") - -el_ids = SoftInpaintingSettings( - "mask_blend_power", - "mask_blend_scale", - "inpaint_detail_preservation") - - -# ------------------- UI ------------------- - - -def gradio_ui(): - import gradio as gr - from modules.ui_components import InputAccordion - - with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: - with gr.Group(): - gr.Markdown( - """ - Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. - **High _Mask blur_** values are recommended! - """) - - result = SoftInpaintingSettings( - gr.Slider(label=ui_labels.mask_blend_power, - info=ui_info.mask_blend_power, - minimum=0, - maximum=8, - step=0.1, - value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), - gr.Slider(label=ui_labels.mask_blend_scale, - info=ui_info.mask_blend_scale, - minimum=0, - maximum=8, - step=0.05, - value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), - gr.Slider(label=ui_labels.inpaint_detail_preservation, - info=ui_info.inpaint_detail_preservation, - minimum=1, - maximum=32, - step=0.5, - value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) - - with gr.Accordion("Help", open=False): - gr.Markdown( - f""" - ### {ui_labels.mask_blend_power} - - The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). - This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. - This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - - - **Below 1**: Stronger preservation near the end (with low sigma) - - **1**: Balanced (proportional to sigma) - - **Above 1**: Stronger preservation in the beginning (with high sigma) - """) - gr.Markdown( - f""" - ### {ui_labels.mask_blend_scale} - - Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. - This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - - - **Low values**: Favors generated content. - - **High values**: Favors original content. - """) - gr.Markdown( - f""" - ### {ui_labels.inpaint_detail_preservation} - - This parameter controls how the original latent vectors and denoised latent vectors are interpolated. - With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. - This can prevent the loss of contrast that occurs with linear interpolation. - - - **Low values**: Softer blending, details may fade. - - **High values**: Stronger contrast, may over-saturate colors. - """) - - return ( - [ - soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation - ], - [ - (soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) - ] - ) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py new file mode 100644 index 00000000..47e0269b --- /dev/null +++ b/scripts/soft_inpainting.py @@ -0,0 +1,401 @@ +import gradio as gr +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def apply_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + soft_inpainting, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation") + + +class Script(scripts.Script): + + def __init__(self): + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation] + + def process(self, p, enabled, power, scale, detail_preservation): + if not enabled: + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + if mba.sigma is None: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + from modules import images + from modules.shared import opts + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(soft_inpainting=settings, + nmask=p.nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file -- cgit v1.2.3 From edfae95d90a49ea95394b772817a59dde4175222 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 23 Dec 2023 01:21:00 +0900 Subject: prevent crash due to Script __init__ exception --- modules/scripts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index b6fcf96e..3a766911 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -566,7 +566,12 @@ class ScriptRunner: auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() for script_data in auto_processing_scripts + scripts_data: - script = script_data.script_class() + try: + script = script_data.script_class() + except Exception: + errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True) + continue + script.filename = script_data.path script.is_txt2img = not is_img2img script.is_img2img = is_img2img -- cgit v1.2.3 From 00901bfbe0095303554f4440b4c12fac262e2e89 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:47:57 +0900 Subject: handle selectable script_index is None --- modules/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 3a766911..017aed5a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -696,6 +696,8 @@ class ScriptRunner: self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): + if script_index is None: + script_index = 0 selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] @@ -739,7 +741,7 @@ class ScriptRunner: def run(self, p, *args): script_index = args[0] - if script_index == 0: + if script_index == 0 or script_index is None: return None script = self.selectable_scripts[script_index-1] -- cgit v1.2.3