-rw-r--r--  modules/devices.py                              | 28
-rw-r--r--  modules/processing.py                           | 28
-rw-r--r--  modules/shared.py                               |  1
-rw-r--r--  modules/textual_inversion/textual_inversion.py  | 23
-rw-r--r--  modules/ui.py                                   |  2
-rw-r--r--  webui.py                                        |  7
6 files changed, 69 insertions(+), 20 deletions(-)
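Not part of the patch: a minimal standalone sketch of the PyTorch MPS cumsum bug (pytorch/pytorch#89784) that the modules/devices.py change below works around. It assumes a Mac with an affected torch build; the tensor values are only illustrative.

# Illustration only. On affected MPS builds, cumsum over bool/int8/int16/int64
# tensors returns wrong values, so the patch below routes those dtypes through
# the CPU before moving the result back to the MPS device.
import torch

if torch.backends.mps.is_available():
    t = torch.ones(5, dtype=torch.int16, device="mps")
    broken = t.cumsum(0)                    # may come back as all ones on affected builds
    fixed = t.cpu().cumsum(0).to(t.device)  # what the workaround does: [1, 2, 3, 4, 5]
    print(broken.tolist(), fixed.tolist())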
diff --git a/modules/devices.py b/modules/devices.py
index 800510b7..caeb0276 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
     return orig_tensor_numpy(self, *args, **kwargs)
 
 
-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
-    torch.Tensor.to = tensor_to_fix
-    torch.nn.functional.layer_norm = layer_norm_fix
-    torch.Tensor.numpy = numpy_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+        torch.Tensor.to = tensor_to_fix
+        torch.nn.functional.layer_norm = layer_norm_fix
+        torch.Tensor.numpy = numpy_fix
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
+        if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+            torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+            torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+        orig_narrow = torch.narrow
+        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
diff --git a/modules/processing.py b/modules/processing.py
index 61e97077..a408d622 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
 
+    cached_uc = [None, None]
+    cached_c = [None, None]
+
+    def get_conds_with_caching(function, required_prompts, steps, cache):
+        """
+        Returns the result of calling function(shared.sd_model, required_prompts, steps)
+        using a cache to store the result if the same arguments have been used before.
+
+        cache is an array containing two elements. The first element is a tuple
+        representing the previously used arguments, or None if no arguments
+        have been used before. The second element is where the previously
+        computed result is stored.
+        """
+
+        if cache[0] is not None and (required_prompts, steps) == cache[0]:
+            return cache[1]
+
+        with devices.autocast():
+            cache[1] = function(shared.sd_model, required_prompts, steps)
+
+        cache[0] = (required_prompts, steps)
+        return cache[1]
+
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.scripts is not None:
                 p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
 
-            with devices.autocast():
-                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
-                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
+            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
+            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
 
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
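Not part of the patch: a self-contained sketch of the two-element cache that get_conds_with_caching uses above, with a hypothetical compute_cond standing in for the real prompt_parser functions.

# cache[0] holds the last (prompts, steps) key, cache[1] the last computed result;
# when consecutive batches use identical prompts, the conditioning is reused.
def compute_cond(model, prompts, steps):          # hypothetical stand-in
    print("computing conditioning for", prompts)
    return object()

def get_with_caching(function, model, prompts, steps, cache):
    if cache[0] is not None and (prompts, steps) == cache[0]:
        return cache[1]
    cache[1] = function(model, prompts, steps)
    cache[0] = (prompts, steps)
    return cache[1]

cached_c = [None, None]
get_with_caching(compute_cond, None, ["a cat"], 20, cached_c)  # computes
get_with_caching(compute_cond, None, ["a cat"], 20, cached_c)  # reused, no recompute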
diff --git a/modules/shared.py b/modules/shared.py
index 10231a75..f0e10b35 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -577,6 +577,7 @@ latent_upscale_modes = {
     "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
     "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
     "Latent (nearest)": {"mode": "nearest", "antialias": False},
+    "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
 }
 
 sd_upscalers = []
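Not part of the patch: the latent_upscale_modes entries are handed to torch.nn.functional.interpolate when latents are resized; a rough sketch of what the new "nearest-exact" entry selects (the call below is illustrative, not the exact webui code path).

import torch
import torch.nn.functional as F

latent = torch.randn(1, 4, 64, 64)  # Stable Diffusion latent: [batch, channels, h, w]
opts = {"mode": "nearest-exact", "antialias": False}

# "nearest-exact" (PyTorch >= 1.11) matches Scikit-Image/Pillow nearest-neighbour
# sampling more closely than the legacy "nearest"; antialias only applies to
# bilinear/bicubic, which is why this entry keeps it False.
upscaled = F.interpolate(latent, size=(128, 128), mode=opts["mode"])
print(upscaled.shape)  # torch.Size([1, 4, 128, 128])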
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ce7e4f5d..e9cf432f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -150,19 +150,20 @@ class EmbeddingDatabase:
             else:
                 self.skipped_embeddings[name] = embedding
 
-        for fn in os.listdir(self.embeddings_dir):
-            try:
-                fullfn = os.path.join(self.embeddings_dir, fn)
-
-                if os.stat(fullfn).st_size == 0:
+        for root, dirs, fns in os.walk(self.embeddings_dir):
+            for fn in fns:
+                try:
+                    fullfn = os.path.join(root, fn)
+
+                    if os.stat(fullfn).st_size == 0:
+                        continue
+
+                    process_file(fullfn, fn)
+                except Exception:
+                    print(f"Error loading embedding {fn}:", file=sys.stderr)
+                    print(traceback.format_exc(), file=sys.stderr)
                     continue
-                process_file(fullfn, fn)
-            except Exception:
-                print(f"Error loading embedding {fn}:", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
-                continue
-
 
         print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
         if len(self.skipped_embeddings) > 0:
             print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
diff --git a/modules/ui.py b/modules/ui.py
index 81d96c5b..030f0685 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -550,6 +550,8 @@ Requested path was: {f}
                 os.startfile(path)
             elif platform.system() == "Darwin":
                 sp.Popen(["open", path])
+            elif "microsoft-standard-WSL2" in platform.uname().release:
+                sp.Popen(["wsl-open", path])
             else:
                 sp.Popen(["xdg-open", path])
 
diff --git a/webui.py b/webui.py
--- a/webui.py
+++ b/webui.py
@@ -4,7 +4,7 @@ import threading
 import time
 import importlib
 import signal
-import threading
+import re
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -13,6 +13,11 @@ from modules import import_hook, errors
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
+import torch
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+    torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0)
+
 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
 import modules.codeformer_model as codeformer
 import modules.extras
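Not part of the patch: what the version-number truncation added above does to nightly/local build strings (the example strings are made up; the trailing dot on ".dev" builds is what the regex actually produces).

import re

# The regex keeps only the leading digits-and-dots prefix, so libraries doing
# naive version checks (per the comment above: CodeFormer, Safetensors) see a
# plain release number instead of a nightly/local build tag.
for raw in ["1.14.0.dev20221220+cu117", "2.0.0a0+git1234567"]:
    if ".dev" in raw or "+git" in raw:
        print(raw, "->", re.search(r'[\d.]+', raw).group(0))
# 1.14.0.dev20221220+cu117 -> 1.14.0.
# 2.0.0a0+git1234567 -> 2.0.0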