From 668d7e9b9aba1770beae48a8664e0351fcd59f31 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 5 Feb 2023 11:20:47 +0300 Subject: make it possible to load SD1 checkpoints without CLIP --- modules/sd_disable_initialization.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'modules/sd_disable_initialization.py') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index e90aa9fe..c4a09d15 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,8 +20,9 @@ class DisableInitialization: ``` """ - def __init__(self): + def __init__(self, disable_clip=True): self.replaced = [] + self.disable_clip = disable_clip def replace(self, obj, field, func): original = getattr(obj, field, None) @@ -75,12 +76,14 @@ class DisableInitialization: self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) - self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) - self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) - self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) - self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) - self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) - self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) + + if self.disable_clip: + self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) + self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) + self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) + self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) + self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) + self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): for obj, field, original in self.replaced: -- cgit v1.2.3