diff options
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..29c8b561 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -59,6 +59,10 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
+def fix_checkpoint():
+ ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack:
fixes = None
@@ -78,6 +82,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model
apply_optimizations()
+ fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
+
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
|