aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/cmd_args.py4
-rw-r--r--modules/esrgan_model.py2
-rw-r--r--modules/extra_networks.py14
-rw-r--r--modules/hashes.py29
-rw-r--r--modules/modelloader.py7
-rw-r--r--modules/realesrgan_model.py2
-rw-r--r--modules/scripts.py21
-rw-r--r--modules/sd_models.py1
-rw-r--r--modules/shared.py9
-rw-r--r--modules/ui.py28
-rw-r--r--modules/ui_extensions.py4
-rw-r--r--modules/ui_extra_networks.py16
-rw-r--r--modules/ui_loadsave.py4
-rw-r--r--modules/upscaler.py1
14 files changed, 94 insertions, 48 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 85db93f3..3eeb84d5 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -12,8 +12,8 @@ parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
-parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
-parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
+parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a009eb42..2fced999 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -154,7 +154,7 @@ class UpscalerESRGAN(Upscaler):
if "http" in path:
filename = load_file_from_url(
url=self.model_url,
- model_dir=self.model_path,
+ model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
progress=True,
)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index f9db41bc..34a3ba63 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -14,9 +14,23 @@ def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_default_extra_networks():
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
+
+
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
+ self.positional = []
+ self.named = {}
+
+ for item in self.items:
+ parts = item.split('=', 2)
+ if len(parts) == 2:
+ self.named[parts[0]] = parts[1]
+ else:
+ self.positional.append(item)
class ExtraNetwork:
diff --git a/modules/hashes.py b/modules/hashes.py
index 032120f4..8b7ea0ac 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -46,8 +46,8 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()
-def sha256_from_cache(filename, title):
- hashes = cache("hashes")
+def sha256_from_cache(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
if title not in hashes:
@@ -62,10 +62,10 @@ def sha256_from_cache(filename, title):
return cached_sha256
-def sha256(filename, title):
- hashes = cache("hashes")
+def sha256(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- sha256_value = sha256_from_cache(filename, title)
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
if sha256_value is not None:
return sha256_value
@@ -73,7 +73,11 @@ def sha256(filename, title):
return None
print(f"Calculating sha256 for {filename}: ", end='')
- sha256_value = calculate_sha256(filename)
+ if use_addnet_hash:
+ with open(filename, "rb") as file:
+ sha256_value = addnet_hash_safetensors(file)
+ else:
+ sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
@@ -86,6 +90,19 @@ def sha256(filename, title):
return sha256_value
+def addnet_hash_safetensors(b):
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 2a479bcb..be23071a 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -47,7 +47,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, model_path, True, download_name)
+ dl = load_file_from_url(model_url, places[0], True, download_name)
output.append(dl)
else:
output.append(model_url)
@@ -144,7 +144,10 @@ def load_upscalers():
for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
- scaler = cls(commandline_options.get(cmd_name, None))
+ commandline_model_path = commandline_options.get(cmd_name, None)
+ scaler = cls(commandline_model_path)
+ scaler.user_path = commandline_model_path
+ scaler.model_download_path = commandline_model_path or scaler.model_path
datas += scaler.scalers
shared.sd_upscalers = sorted(
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index c24d8dbb..99983678 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -73,7 +73,7 @@ class UpscalerRealESRGAN(Upscaler):
return None
if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info
except Exception as e:
diff --git a/modules/scripts.py b/modules/scripts.py
index e33d8c81..c902804b 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -271,6 +271,12 @@ def load_scripts():
sys.path = syspath
current_basedir = paths.script_path
+ global scripts_txt2img, scripts_img2img, scripts_postproc
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -527,9 +533,9 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
-scripts_txt2img = ScriptRunner()
-scripts_img2img = ScriptRunner()
-scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+scripts_txt2img: ScriptRunner = None
+scripts_img2img: ScriptRunner = None
+scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None
@@ -539,14 +545,7 @@ def reload_script_body_only():
scripts_img2img.reload_sources(cache)
-def reload_scripts():
- global scripts_txt2img, scripts_img2img, scripts_postproc
-
- load_scripts()
-
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+reload_scripts = load_scripts # compatibility alias
def add_classes_to_gradio_component(comp):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8e42bfea..b1afbaa7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -98,7 +98,6 @@ def setup_model():
if not os.path.exists(model_path):
os.makedirs(model_path)
- list_models()
enable_midas_autodownload()
diff --git a/modules/shared.py b/modules/shared.py
index 7cfbaa0c..3099d1d2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -15,6 +15,7 @@ import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import Optional
demo = None
@@ -113,7 +114,7 @@ class State:
time_start = None
server_start = None
_server_command_signal = threading.Event()
- _server_command: str | None = None
+ _server_command: Optional[str] = None
@property
def need_restart(self) -> bool:
@@ -131,14 +132,14 @@ class State:
return self._server_command
@server_command.setter
- def server_command(self, value: str | None) -> None:
+ def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
- def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
@@ -472,7 +473,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
- "js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
+ "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
diff --git a/modules/ui.py b/modules/ui.py
index 70a597d7..82820ab5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -463,8 +463,8 @@ def create_ui():
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
@@ -526,14 +526,16 @@ def create_ui():
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- def update_resolution_hires_input(inp, evt):
- getattr(inp, evt)(
+ for component in hr_resolution_preview_inputs:
+ event = component.release if isinstance(component, gr.Slider) else component.change
+
+ event(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False,
)
- getattr(inp, evt)(
+ event(
None,
_js="onCalcResolutionHires",
inputs=hr_resolution_preview_inputs,
@@ -541,10 +543,6 @@ def create_ui():
show_progress=False,
)
- update_resolution_hires_input(enable_hr, 'change')
- for input in hr_resolution_preview_inputs[1:]:
- update_resolution_hires_input(input, 'release')
-
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
@@ -794,8 +792,8 @@ def create_ui():
with gr.Tab(label="Resize to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
@@ -1185,8 +1183,8 @@ def create_ui():
with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_process_height")
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
@@ -1278,8 +1276,8 @@ def create_ui():
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_training_height")
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 4ba3bdd7..ef18f438 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -125,7 +125,9 @@ def make_commit_link(commit_hash, remote, text=None):
if text is None:
text = commit_hash[:8]
if remote.startswith("https://github.com/"):
- href = os.path.join(remote, "commit", commit_hash)
+ if remote.endswith(".git"):
+ remote = remote[:-4]
+ href = remote + "/commit/" + commit_hash
return f'<a href="{href}" target="_blank">{text}</a>'
else:
return text
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8bd0722e..19fbaae5 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -161,7 +161,7 @@ class ExtraNetworksPage:
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
- background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
+ background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
metadata_button = ""
metadata = item.get("metadata")
if metadata:
@@ -186,7 +186,8 @@ class ExtraNetworksPage:
return ""
args = {
- "style": f"'display: none; {height}{width}{background_image}'",
+ "background_image": background_image,
+ "style": f"'display: none; {height}{width}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
@@ -231,10 +232,19 @@ class ExtraNetworksPage:
return None
-def intialize():
+def initialize():
extra_pages.clear()
+def register_default_pages():
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
+ register_page(ExtraNetworksPageTextualInversion())
+ register_page(ExtraNetworksPageHypernetworks())
+ register_page(ExtraNetworksPageCheckpoints())
+
+
class ExtraNetworksUi:
def __init__(self):
self.pages = None
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index 728fec9e..0052a5cc 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -55,7 +55,7 @@ class UiLoadsave:
if field == 'value' and key not in self.component_mapping:
self.component_mapping[key] = x
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@@ -109,6 +109,8 @@ class UiLoadsave:
self.add_block(c, path)
elif x.label is not None:
self.add_component(f"{path}/{x.label}", x)
+ elif isinstance(x, gr.Button) and x.value is not None:
+ self.add_component(f"{path}/{x.value}", x)
def read_from_file(self):
with open(self.filename, "r", encoding="utf8") as file:
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 8acb6e96..7b1046d6 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -34,6 +34,7 @@ class Upscaler:
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
+ self.model_download_path = None
if self.model_path is None and self.name:
self.model_path = os.path.join(shared.models_path, self.name)