aboutsummaryrefslogtreecommitdiffstats
path: root/modules/safe.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-12-25 06:03:56 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-12-25 06:03:56 +0000
commit8eef9d8e782aa0655241e43f67059aa7bef3bdca (patch)
tree4c7e19a33c9609394733a09b80603f9ac46819db /modules/safe.py
parentc5bdba2089dc7060be2631bcbc83313b6358cbf2 (diff)
downloadstable-diffusion-webui-gfx803-8eef9d8e782aa0655241e43f67059aa7bef3bdca.tar.gz
stable-diffusion-webui-gfx803-8eef9d8e782aa0655241e43f67059aa7bef3bdca.tar.bz2
stable-diffusion-webui-gfx803-8eef9d8e782aa0655241e43f67059aa7bef3bdca.zip
a way to add an exception to unpickler without explicitly calling load_with_extra
Diffstat (limited to 'modules/safe.py')
-rw-r--r--modules/safe.py39
1 files changed, 38 insertions, 1 deletions
diff --git a/modules/safe.py b/modules/safe.py
index 479c8b86..ec23a53c 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs):
- return load_with_extra(filename, *args, **kwargs)
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
return unsafe_torch_load(filename, *args, **kwargs)
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler
+
+ global_extra_handler = None
+
+
unsafe_torch_load = torch.load
torch.load = load
+global_extra_handler = None
+