diff options
-rw-r--r-- | modules/api/api.py | 7 | ||||
-rw-r--r-- | modules/processing.py | 4 | ||||
-rw-r--r-- | modules/safe.py | 18 |
3 files changed, 18 insertions, 11 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 1de3f98f..54ee7cb0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -152,7 +152,10 @@ class Api: ) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on - p = StableDiffusionProcessingImg2Img(**vars(populate)) + + args = vars(populate) + args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + p = StableDiffusionProcessingImg2Img(**args) imgs = [] for img in init_images: @@ -170,7 +173,7 @@ class Api: b64images = list(map(encode_pil_to_base64, processed.images)) - if (not img2imgreq.include_init_images): + if not img2imgreq.include_init_images: img2imgreq.init_images = None img2imgreq.mask = None diff --git a/modules/processing.py b/modules/processing.py index edceb532..fd995b8a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- samples_ddim = samples_ddim.to(devices.dtype_vae)
- x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
diff --git a/modules/safe.py b/modules/safe.py index a9209e38..10460ad0 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler): raise Exception(f"global '{module}/{name}' is forbidden")
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
- if name in allowed_zip_names:
- continue
if allowed_zip_names_re.match(name):
continue
@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler): # new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
-
- with z.open('archive/data.pkl') as file:
+
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
|