From 6ad666e4794a57dd65790dd6a259d5d4330d45ed Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 19:46:20 +0300 Subject: more changes for #13865: fix formatting, rename the function, add comment and add a readme entry --- modules/sd_hijack.py | 24 +++++++++++++----------- modules/sd_models.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c6d17764..fba23c38 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -184,17 +184,19 @@ class StableDiffusionModelHijack: errors.display(e, "applying cross attention optimization") undo_optimizations() - def conv_ssd(self, m): - delattr(m.model.diffusion_model.middle_block, '1') - delattr(m.model.diffusion_model.middle_block, '2') - for i in ['9','8','7','6','5','4']: - delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i) - delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1') - delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') - devices.torch_gc() + def convert_sdxl_to_ssd(self, m): + """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)""" + + delattr(m.model.diffusion_model.middle_block, '1') + delattr(m.model.diffusion_model.middle_block, '2') + for i in ['9', '8', '7', '6', '5', '4']: + delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i) + delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1') + devices.torch_gc() def hijack(self, m): conditioner = getattr(m, 'conditioner', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1036a3b1..841402e8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) if model.is_ssd: - sd_hijack.model_hijack.conv_ssd(model) + sd_hijack.model_hijack.convert_sdxl_to_ssd(model) if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model -- cgit v1.2.3