aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-06 04:51:45 +0000
committerGitHub <noreply@github.com>2023-01-06 04:51:45 +0000
commit85fa4eacea859ef2c78d719f36b25a2e19dd0be1 (patch)
tree1597ed6d494ffb1efd95023f9283be1f9f473b82
parent3ea354f274fe91bb224fe8bbe72ae215ac6622cf (diff)
parent8111b5569d07c7ac3b695e28171aede728b4ae56 (diff)
downloadstable-diffusion-webui-gfx803-85fa4eacea859ef2c78d719f36b25a2e19dd0be1.tar.gz
stable-diffusion-webui-gfx803-85fa4eacea859ef2c78d719f36b25a2e19dd0be1.tar.bz2
stable-diffusion-webui-gfx803-85fa4eacea859ef2c78d719f36b25a2e19dd0be1.zip
Merge pull request #6402 from brkirch/work-with-nightly-local-builds
Add support for using PyTorch nightly and local builds
-rw-r--r--modules/devices.py28
-rw-r--r--webui.py7
2 files changed, 29 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 800510b7..caeb0276 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
return orig_tensor_numpy(self, *args, **kwargs)
-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+ if input.device.type == 'mps':
+ output_dtype = kwargs.get('dtype', input.dtype)
+ if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+ if version.parse(torch.__version__) < version.parse("1.13"):
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+ torch.Tensor.to = tensor_to_fix
+ torch.nn.functional.layer_norm = layer_norm_fix
+ torch.Tensor.numpy = numpy_fix
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
+ if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+ torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+ torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ orig_narrow = torch.narrow
+ torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
diff --git a/webui.py b/webui.py
index d89e0fb5..ff6eb6eb 100644
--- a/webui.py
+++ b/webui.py
@@ -4,7 +4,7 @@ import threading
import time
import importlib
import signal
-import threading
+import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -13,6 +13,11 @@ from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
+import torch
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0)
+
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer
import modules.extras