diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-16 13:53:13 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-16 13:53:13 +0000 |
commit | cd12256575dcce325519ef674323d953fbce252c (patch) | |
tree | 61cb8998f8971cbdbd3302ff1b7edda6e2737b99 /modules/sd_hijack.py | |
parent | c3facab495e6bb29b5e0b16d064b44851a733a95 (diff) | |
parent | 5e80d9ee99c5899e5e2b130408ffb65a0585a62a (diff) | |
download | stable-diffusion-webui-gfx803-cd12256575dcce325519ef674323d953fbce252c.tar.gz stable-diffusion-webui-gfx803-cd12256575dcce325519ef674323d953fbce252c.tar.bz2 stable-diffusion-webui-gfx803-cd12256575dcce325519ef674323d953fbce252c.zip |
Merge branch 'dev' into test-fp8
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bc5fbcd3..0157e19f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -184,6 +184,20 @@ class StableDiffusionModelHijack: errors.display(e, "applying cross attention optimization")
undo_optimizations()
+ def convert_sdxl_to_ssd(self, m):
+ """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+ delattr(m.model.diffusion_model.middle_block, '1')
+ delattr(m.model.diffusion_model.middle_block, '2')
+ for i in ['9', '8', '7', '6', '5', '4']:
+ delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+ delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+ devices.torch_gc()
+
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
@@ -242,8 +256,12 @@ class StableDiffusionModelHijack: self.layers = flatten(m)
+ import modules.models.diffusion.ddpm_edit
+
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
|