author     papuSpartan <30642826+papuSpartan@users.noreply.github.com>  2022-12-10 08:02:39 +0000
committer  GitHub <noreply@github.com>                                  2022-12-10 08:02:39 +0000
commit     6387043fd2c3311d66690ff27d7da0e030b29cd8 (patch)
tree       075559ce3e52511cdd459da2b4cc33360d6eb258 /modules/interrogate.py
parent     00ebc26c4e2962a31e067539d89cd695d999539a (diff)
parent     1d01404c5615debfca24f7fbe429ddd2f5b11eb9 (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--  modules/interrogate.py  16
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9769aa34..0068b81c 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,4 +1,3 @@
-import contextlib
 import os
 import sys
 import traceback
@@ -11,10 +10,9 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
 
 blip_image_eval_size = 384
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
 clip_model_name = 'ViT-L/14'
 
 Category = namedtuple("Category", ["name", "topn", "items"])
@@ -47,7 +45,14 @@ class InterrogateModels:
     def load_blip_model(self):
         import models.blip
 
-        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+        files = modelloader.load_models(
+            model_path=os.path.join(paths.models_path, "BLIP"),
+            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+            ext_filter=[".pth"],
+            download_name='model_base_caption_capfilt_large.pth',
+        )
+
+        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
         blip_model.eval()
 
         return blip_model
@@ -148,8 +153,7 @@ class InterrogateModels:
         clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
 
-        precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
-        with torch.no_grad(), precision_scope("cuda"):
+        with torch.no_grad(), devices.autocast():
             image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
 
             image_features /= image_features.norm(dim=-1, keepdim=True)
 
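The substance of the merge is the `load_blip_model` hunk: instead of handing `blip_decoder` a raw URL, the code now asks `modelloader.load_models` for a local checkpoint path under `models/BLIP`, so the BLIP weights are cached on disk and reused across runs. A minimal sketch of that resolve-or-download contract, assuming only the keyword arguments visible in the diff (`load_models_sketch` is a hypothetical stand-in, not the repository's actual `modules/modelloader.py`):

import os
import urllib.request

def load_models_sketch(model_path, model_url=None, ext_filter=None, download_name=None):
    """Return local checkpoint paths, downloading once if none exist yet."""
    os.makedirs(model_path, exist_ok=True)

    # Reuse any checkpoint already present that matches the extension filter.
    found = [
        os.path.join(model_path, name)
        for name in sorted(os.listdir(model_path))
        if not ext_filter or os.path.splitext(name)[1] in ext_filter
    ]
    if found:
        return found

    # First run: fetch the checkpoint to model_path/download_name.
    if model_url and download_name:
        target = os.path.join(model_path, download_name)
        urllib.request.urlretrieve(model_url, target)
        return [target]

    return []

Under that contract, `files[0]` in the hunk is the cached model_base_caption_capfilt_large.pth, whether it was just downloaded or found from an earlier run.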
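The final hunk drops the locally built `precision_scope` (and with it the `contextlib` import removed in the first hunk) in favor of the shared `devices.autocast()` helper. A sketch of the behavior being centralized, assuming the helper folds in the same `--precision` check the old inline code made (`autocast_sketch` is hypothetical, not the actual `modules/devices.py`):

import contextlib
import torch

def autocast_sketch(precision: str = "autocast"):
    # Same decision the removed lines made inline: run CLIP under CUDA
    # autocast unless full precision was requested on the command line.
    if precision != "autocast" or not torch.cuda.is_available():
        return contextlib.nullcontext()
    return torch.autocast("cuda")

# Call-site shape after the merge (shared.cmd_opts.precision comes from the
# webui's --precision flag):
# with torch.no_grad(), autocast_sketch(shared.cmd_opts.precision):
#     image_features = clip_model.encode_image(clip_image)

Centralizing the check means every interrogate call site gets consistent precision handling instead of re-deriving it from `shared.cmd_opts`.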