aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-06-27 06:23:15 +0000
committerGitHub <noreply@github.com>2023-06-27 06:23:15 +0000
commit0b97ae2832a89b103be1c062d7778e558732712f (patch)
tree6d0966156f1cfa78aa0e97ed0474afaffc50f2cd /modules/sd_hijack_optimizations.py
parent7e2d39a2d158d1426321686b05d31a3ea694a99e (diff)
parent3cd4fd51ef916aba8b978490569a5da82795a839 (diff)
downloadstable-diffusion-webui-gfx803-0b97ae2832a89b103be1c062d7778e558732712f.tar.gz
stable-diffusion-webui-gfx803-0b97ae2832a89b103be1c062d7778e558732712f.tar.bz2
stable-diffusion-webui-gfx803-0b97ae2832a89b103be1c062d7778e558732712f.zip
Merge branch 'dev' into master
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 80e48a42..53e27ade 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,5 @@
from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -48,7 +46,7 @@ class SdOptimizationXformers(SdOptimization):
priority = 100
def is_available(self):
- return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
@@ -140,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Cannot import xformers", exc_info=True)
def get_available_vram():
@@ -605,7 +602,7 @@ def sdp_attnblock_forward(self, x):
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
+ q, k, v = q.float(), k.float(), v.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()