diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 83 | ||||
-rw-r--r-- | modules/api/models.py | 2 | ||||
-rw-r--r-- | modules/mac_specific.py | 3 | ||||
-rw-r--r-- | modules/memmon.py | 12 | ||||
-rw-r--r-- | modules/models/diffusion/uni_pc/uni_pc.py | 2 | ||||
-rw-r--r-- | modules/sd_models.py | 24 | ||||
-rw-r--r-- | modules/sd_vae_approx.py | 5 | ||||
-rw-r--r-- | modules/shared.py | 1 | ||||
-rw-r--r-- | modules/timer.py | 3 | ||||
-rw-r--r-- | modules/ui.py | 3 | ||||
-rw-r--r-- | modules/ui_extensions.py | 2 | ||||
-rw-r--r-- | modules/ui_extra_networks.py | 13 |
12 files changed, 124 insertions, 29 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 376f7f04..35e17afc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -164,14 +164,10 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def get_script(self, script_name, script_runner): - if script_name is None: + def get_selectable_script(self, script_name, script_runner): + if script_name is None or script_name == "": return None, None - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx @@ -182,8 +178,49 @@ class Api: return ScriptsList(txt2img = t2ilist, img2img = i2ilist) + def get_script(self, script_name, script_runner): + if script_name is None or script_name == "": + return None, None + + script_idx = script_name_to_index(script_name, script_runner.scripts) + return script_runner.scripts[script_idx] + + def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner): + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere except position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() + if selectable_scripts: + script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args + script_args[0] = selectable_idx + 1 + else: + # when [0] = 0 no selectable script to run + script_args[0] = 0 + + # Now check for always on scripts + if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + for alwayson_script_name in request.alwayson_scripts.keys(): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + # always on script with no arg should always run so you don't really need to add them to the requests + if "args" in request.alwayson_scripts[alwayson_script_name]: + script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"] + return script_args + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) + script_runner = scripts.scripts_txt2img + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -195,20 +232,26 @@ class Api: args = vars(populate) args.pop('script_name', None) + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them + args.pop('alwayson_scripts', None) + + script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) send_images = args.pop('send_images', True) args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples shared.state.begin() - if script is not None: - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) + if selectable_scripts != None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() @@ -221,12 +264,16 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) - mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": not img2imgreq.save_images, @@ -239,6 +286,10 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them + args.pop('alwayson_scripts', None) + + script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) send_images = args.pop('send_images', True) args.pop('save_images', None) @@ -246,14 +297,16 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples shared.state.begin() - if script is not None: - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) + if selectable_scripts != None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() diff --git a/modules/api/models.py b/modules/api/models.py index fa1c40df..4a70f440 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -106,6 +106,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "script_args", "type": list, "default": []}, {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, + {"key": "alwayson_scripts", "type": dict, "default": {}}, ] ).generate_model() @@ -122,6 +123,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "script_args", "type": list, "default": []}, {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, + {"key": "alwayson_scripts", "type": dict, "default": {}}, ] ).generate_model() diff --git a/modules/mac_specific.py b/modules/mac_specific.py index ddcea53b..18e6ff72 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): output_dtype = kwargs.get('dtype', input.dtype) if output_dtype == torch.int64: return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) - elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input, *args, **kwargs) @@ -45,7 +45,6 @@ if has_mps: CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) - cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) diff --git a/modules/memmon.py b/modules/memmon.py index a7060f58..4018edcc 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread): self.data = defaultdict(int) try: - torch.cuda.mem_get_info() + self.cuda_mem_get_info() torch.cuda.memory_stats(self.device) except Exception as e: # AMD or whatever print(f"Warning: caught exception '{e}', memory monitor disabled") self.disabled = True + def cuda_mem_get_info(self): + index = self.device.index if self.device.index is not None else torch.cuda.current_device() + return torch.cuda.mem_get_info(index) + def run(self): if self.disabled: return @@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread): self.run_flag.clear() continue - self.data["min_free"] = torch.cuda.mem_get_info()[0] + self.data["min_free"] = self.cuda_mem_get_info()[0] while self.run_flag.is_set(): - free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? + free, total = self.cuda_mem_get_info() self.data["min_free"] = min(self.data["min_free"], free) time.sleep(1 / self.opts.memmon_poll_rate) @@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread): def read(self): if not self.disabled: - free, total = torch.cuda.mem_get_info() + free, total = self.cuda_mem_get_info() self.data["free"] = free self.data["total"] = total diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index df63d1bc..e9a093a2 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -719,7 +719,7 @@ class UniPC: x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) else: x_t_ = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0 ) if x_t is None: diff --git a/modules/sd_models.py b/modules/sd_models.py index 93959f55..5f57ec0c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd
+def read_metadata_from_safetensors(filename):
+ import json
+
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+
+ res = {}
+ for k, v in json_obj.get("__metadata__", {}).items():
+ res[k] = v
+ if isinstance(v, str) and v[0] == '{':
+ try:
+ res[k] = json.loads(v)
+ except Exception as e:
+ pass
+
+ return res
+
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 0027343a..e2f00468 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -35,8 +35,11 @@ def model(): global sd_vae_approx_model
if sd_vae_approx_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
+ sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py index 2fb9e3b5..f28a12cc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -714,6 +714,7 @@ class TotalTQDM: def clear(self):
if self._tqdm is not None:
+ self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
diff --git a/modules/timer.py b/modules/timer.py index 57a4f17a..ba92be33 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -33,3 +33,6 @@ class Timer: res += ")"
return res
+
+ def reset(self):
+ self.__init__()
diff --git a/modules/ui.py b/modules/ui.py index 621ae952..7e603332 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1751,7 +1751,8 @@ def create_ui(): def reload_javascript():
- head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
+ script_js = os.path.join(script_path, "script.js")
+ head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index bd4308ef..df75a925 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -304,7 +304,7 @@ def create_ui(): with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 85f0af4c..cdfd6f2a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -30,8 +30,8 @@ def add_pages_to_demo(app): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -124,6 +124,12 @@ class ExtraNetworksPage: if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ metadata_button = ""
+ metadata = item.get("metadata")
+ if metadata:
+ metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item.get("prompt", None),
@@ -134,6 +140,7 @@ class ExtraNetworksPage: "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
+ "metadata_button": metadata_button,
}
return self.card_page.format(**args)
@@ -213,7 +220,6 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
@@ -224,7 +230,6 @@ def create_ui(container, button, tabname): state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
- button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []
|