diff options
author | InvincibleDude <81354513+InvincibleDude@users.noreply.github.com> | 2023-02-05 15:02:44 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-05 15:02:44 +0000 |
commit | f4b78e73a424299a496801930e6d8868d8d03e61 (patch) | |
tree | 48884e8a2ba070d8640f79c1676ffff3e35f37e7 /modules/sd_disable_initialization.py | |
parent | 3ec2eb8bf12ae629c292ed0e96f199669040c5de (diff) | |
parent | ea9bd9fc7409109adcd61b897abc2c8881161256 (diff) | |
download | stable-diffusion-webui-gfx803-f4b78e73a424299a496801930e6d8868d8d03e61.tar.gz stable-diffusion-webui-gfx803-f4b78e73a424299a496801930e6d8868d8d03e61.tar.bz2 stable-diffusion-webui-gfx803-f4b78e73a424299a496801930e6d8868d8d03e61.zip |
Merge branch 'AUTOMATIC1111:master' into improved-hr-conflict-test
Diffstat (limited to 'modules/sd_disable_initialization.py')
-rw-r--r-- | modules/sd_disable_initialization.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index e90aa9fe..c4a09d15 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,8 +20,9 @@ class DisableInitialization: ```
"""
- def __init__(self):
+ def __init__(self, disable_clip=True):
self.replaced = []
+ self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
@@ -75,12 +76,14 @@ class DisableInitialization: self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
|