aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
authorsuperhero-7 <537093830@qq.com>2023-09-23 09:51:41 +0000
committersuperhero-7 <537093830@qq.com>2023-09-23 09:51:41 +0000
commit702a1e1cc70240f2adbcfb707a644a5a98b5443c (patch)
tree56124ff0ea213017d0c8d2f4aaddf6689578e2ac /modules/sd_hijack.py
parent5ef669de080814067961f28357256e8fe27544f4 (diff)
downloadstable-diffusion-webui-gfx803-702a1e1cc70240f2adbcfb707a644a5a98b5443c.tar.gz
stable-diffusion-webui-gfx803-702a1e1cc70240f2adbcfb707a644a5a98b5443c.tar.bz2
stable-diffusion-webui-gfx803-702a1e1cc70240f2adbcfb707a644a5a98b5443c.zip
support m18
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f0055..ae9b2a65 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -208,11 +208,10 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
@@ -258,7 +257,6 @@ class StableDiffusionModelHijack:
if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model')
-
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped