diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/modelloader.py | 1 | ||||
-rw-r--r-- | modules/safe.py | 18 |
2 files changed, 12 insertions, 7 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index e4a6f8ac..7d2f0ade 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,6 +82,7 @@ def cleanup_models(): src_path = models_path dest_path = os.path.join(models_path, "Stable-diffusion") move_files(src_path, dest_path, ".ckpt") + move_files(src_path, dest_path, ".safetensors") src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) diff --git a/modules/safe.py b/modules/safe.py index a9209e38..10460ad0 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler): raise Exception(f"global '{module}/{name}' is forbidden")
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
- if name in allowed_zip_names:
- continue
if allowed_zip_names_re.match(name):
continue
@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler): # new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
-
- with z.open('archive/data.pkl') as file:
+
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
|