aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py28
1 files changed, 24 insertions, 4 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bc5fbcd3..e139d996 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
+
+sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
+
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -184,6 +188,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
+ def convert_sdxl_to_ssd(self, m):
+ """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+ delattr(m.model.diffusion_model.middle_block, '1')
+ delattr(m.model.diffusion_model.middle_block, '2')
+ for i in ['9', '8', '7', '6', '5', '4']:
+ delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+ delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+ devices.torch_gc()
+
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
@@ -242,8 +260,12 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ import modules.models.diffusion.ddpm_edit
+
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
@@ -285,8 +307,6 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- sd_unet.original_forward = None
-
def apply_circular(self, enable):
if self.circular_enabled == enable: