From 8acc901ba3a252dc6ab4fabcb41644cf64d1774c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 00:38:55 -0400 Subject: Newer versions of PyTorch use TypedStorage instead Pytorch 1.13 and later will rename _TypedStorage to TypedStorage, so check for TypedStorage and use _TypedStorage if it is not available. Currently this is needed so that nightly builds of PyTorch work correctly. --- modules/safe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index 4d06f2a5..05917463 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -12,6 +12,10 @@ import _codecs import zipfile +# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage + + def encode(*args): out = _codecs.encode(*args) return out @@ -20,7 +24,7 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return torch.storage._TypedStorage() + return TypedStorage() def find_class(self, module, name): if module == 'collections' and name == 'OrderedDict': -- cgit v1.2.3 From 66b7d7584f0b44ce1316425808c27ca7df38293c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 17:03:00 +0300 Subject: become even stricter with pickles no pickle shall pass thank you again, RyotaK --- modules/safe.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index 05917463..20be16a5 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -10,6 +10,7 @@ import torch import numpy import _codecs import zipfile +import re # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage @@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler): raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") +allowed_zip_names = ["archive/data.pkl", "archive/version"] +allowed_zip_names_re = re.compile(r"^archive/data/\d+$") + + +def check_zip_filenames(filename, names): + for name in names: + if name in allowed_zip_names: + continue + if allowed_zip_names_re.match(name): + continue + + raise Exception(f"bad file inside {filename}: {name}") + + def check_pt(filename): try: # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: + check_zip_filenames(filename, z.namelist()) + with z.open('archive/data.pkl') as file: unpickler = RestrictedUnpickler(file) unpickler.load() -- cgit v1.2.3