diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-04 15:57:14 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-04 15:57:14 +0000 |
commit | 32547f2721c92794779e6ff9fb325243d5857cae (patch) | |
tree | d4d5f1a9705e59eef5029cc1be3bff57fbd389c2 /modules/safe.py | |
parent | fe6e2362e8fa5d739de6997ab155a26686d20a49 (diff) | |
parent | 3dae545a03f5102ba5d9c3f27bb6241824c5a916 (diff) | |
download | stable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.tar.gz stable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.tar.bz2 stable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.zip |
Merge branch 'master' into xygrid_infotext_improvements
Diffstat (limited to 'modules/safe.py')
-rw-r--r-- | modules/safe.py | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/modules/safe.py b/modules/safe.py new file mode 100644 index 00000000..82d44be3 --- /dev/null +++ b/modules/safe.py @@ -0,0 +1,192 @@ +# this code is adapted from the script contributed by anon from /h/
+
+import io
+import pickle
+import collections
+import sys
+import traceback
+
+import torch
+import numpy
+import _codecs
+import zipfile
+import re
+
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
+
+
+def encode(*args):
+ out = _codecs.encode(*args)
+ return out
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
+ def persistent_load(self, saved_id):
+ assert saved_id[0] == 'storage'
+ return TypedStorage()
+
+ def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
+ if module == 'collections' and name == 'OrderedDict':
+ return getattr(collections, name)
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
+ return getattr(torch._utils, name)
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
+ return getattr(torch, name)
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
+ return getattr(torch.nn.modules.container, name)
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(numpy.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(numpy, name)
+ if module == '_codecs' and name == 'encode':
+ return encode
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
+ import pytorch_lightning.callbacks
+ return pytorch_lightning.callbacks.model_checkpoint
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
+ import pytorch_lightning.callbacks.model_checkpoint
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
+ if module == "__builtin__" and name == 'set':
+ return set
+
+ # Forbid everything else.
+ raise Exception(f"global '{module}/{name}' is forbidden")
+
+
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}")
+
+
+def check_pt(filename, extra_handler):
+ try:
+
+ # new pytorch format is a zip file
+ with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
+ unpickler.load()
+
+ except zipfile.BadZipfile:
+
+ # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ with open(filename, "rb") as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
+ for i in range(5):
+ unpickler.load()
+
+
+def load(filename, *args, **kwargs):
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this function is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
+ from modules import shared
+
+ try:
+ if not shared.cmd_opts.disable_safe_unpickle:
+ check_pt(filename, extra_handler)
+
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
+ except Exception:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ return None
+
+ return unsafe_torch_load(filename, *args, **kwargs)
+
+
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler
+
+ global_extra_handler = None
+
+
+unsafe_torch_load = torch.load
+torch.load = load
+global_extra_handler = None
+
|