aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_inpainting.py
diff options
context:
space:
mode:
authorunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
committerunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
commit876da1259965130603f2a7fea505cfa0fce09e2e (patch)
treeccb8b89d64480a4bd224b311702ffeb13b8fe754 /modules/sd_hijack_inpainting.py
parentd6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6 (diff)
parentc6f347b81f584b6c0d44af7a209983284dbb52d2 (diff)
downloadstable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.gz
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.bz2
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.zip
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui
Diffstat (limited to 'modules/sd_hijack_inpainting.py')
-rw-r--r--modules/sd_hijack_inpainting.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 938f9a58..bb5499b3 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,3 +1,4 @@
+import os
import torch
from einops import repeat
@@ -209,7 +210,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
-
+
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
@@ -278,7 +279,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
-
+
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
@@ -319,17 +320,18 @@ class LatentInpaintDiffusion(LatentDiffusion):
def should_hijack_inpainting(checkpoint_info):
- return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
+ cfg_basename = os.path.basename(checkpoint_info.config).lower()
+ return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack():
# most of this stuff seems to no longer be needed because it is already included into SD2.0
- # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
- # this file should be cleaned up later if weverything tuens out to work fine
+ # this file should be cleaned up later if everything turns out to work fine
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
- ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+ # ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim