From 610a7f4e1480c0ffeedb2a07dc27ae86bf03c3a8 Mon Sep 17 00:00:00 2001 From: Edouard Leurent Date: Sat, 8 Oct 2022 16:49:43 +0100 Subject: Break after finding the local directory of stable diffusion Otherwise, we may override it with one of the next two path (. or ..) if it is present there, and then the local paths of other modules (taming transformers, codeformers, etc.) wont be found in sd_path/../. Fix https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/1085 --- modules/paths.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/paths.py') diff --git a/modules/paths.py b/modules/paths.py index 606f7d66..0519caa0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -12,6 +12,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), for possible_sd_path in possible_sd_paths: if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): sd_path = os.path.abspath(possible_sd_path) + break assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) -- cgit v1.2.3 From 875ddfeecfaffad9eee24813301637cba310337d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 17:58:43 +0300 Subject: added guard for torch.load to prevent loading pickles with unknown content --- modules/paths.py | 1 + modules/safe.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 91 insertions(+) create mode 100644 modules/safe.py (limited to 'modules/paths.py') diff --git a/modules/paths.py b/modules/paths.py index 0519caa0..1e7a2fbc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,6 +1,7 @@ import argparse import os import sys +import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) models_path = os.path.join(script_path, "models") diff --git a/modules/safe.py b/modules/safe.py new file mode 100644 index 00000000..2d2c1371 --- /dev/null +++ b/modules/safe.py @@ -0,0 +1,89 @@ +# this code is adapted from the script contributed by anon from /h/ + +import io +import pickle +import collections +import sys +import traceback + +import torch +import numpy +import _codecs +import zipfile + + +def encode(*args): + out = _codecs.encode(*args) + return out + + +class RestrictedUnpickler(pickle.Unpickler): + def persistent_load(self, saved_id): + assert saved_id[0] == 'storage' + return torch.storage._TypedStorage() + + def find_class(self, module, name): + if module == 'collections' and name == 'OrderedDict': + return getattr(collections, name) + if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: + return getattr(torch._utils, name) + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + return getattr(torch, name) + if module == 'torch.nn.modules.container' and name in ['ParameterDict']: + return getattr(torch.nn.modules.container, name) + if module == 'numpy.core.multiarray' and name == 'scalar': + return numpy.core.multiarray.scalar + if module == 'numpy' and name == 'dtype': + return numpy.dtype + if module == '_codecs' and name == 'encode': + return encode + if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.model_checkpoint + if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks.model_checkpoint + return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "__builtin__" and name == 'set': + return set + + # Forbid everything else. + raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + + +def check_pt(filename): + try: + + # new pytorch format is a zip file + with zipfile.ZipFile(filename) as z: + with z.open('archive/data.pkl') as file: + unpickler = RestrictedUnpickler(file) + unpickler.load() + + except zipfile.BadZipfile: + + # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle + with open(filename, "rb") as file: + unpickler = RestrictedUnpickler(file) + for i in range(5): + unpickler.load() + + +def load(filename, *args, **kwargs): + from modules import shared + + try: + if not shared.cmd_opts.disable_safe_unpickle: + check_pt(filename) + + except Exception: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + return None + + return unsafe_torch_load(filename, *args, **kwargs) + + +unsafe_torch_load = torch.load +torch.load = load diff --git a/modules/shared.py b/modules/shared.py index 6ecc2503..3d7f08e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -65,6 +65,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) cmd_opts = parser.parse_args() -- cgit v1.2.3