diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-12-03 15:45:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-12-03 15:45:51 +0000 |
commit | 4b0dc206edbad90affe609ac0bf2e9be7e197674 (patch) | |
tree | 0c5e09982546dfce81530b3facbc517fe061b85d /modules/interrogate.py | |
parent | 2a649154ec994063b27a6723afa40e52be219771 (diff) | |
download | stable-diffusion-webui-gfx803-4b0dc206edbad90affe609ac0bf2e9be7e197674.tar.gz stable-diffusion-webui-gfx803-4b0dc206edbad90affe609ac0bf2e9be7e197674.tar.bz2 stable-diffusion-webui-gfx803-4b0dc206edbad90affe609ac0bf2e9be7e197674.zip |
use modelloader for #4956
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r-- | modules/interrogate.py | 22 |
1 files changed, 8 insertions, 14 deletions
diff --git a/modules/interrogate.py b/modules/interrogate.py index 3a09b366..0068b81c 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,4 +1,3 @@ -import contextlib
import os
import sys
import traceback
@@ -11,12 +10,9 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384
-blip_local_dir = os.path.join('models', 'Interrogator')
-blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
@@ -49,16 +45,14 @@ class InterrogateModels: def load_blip_model(self):
import models.blip
- if not os.path.isfile(blip_local_file):
- if not os.path.isdir(blip_local_dir):
- os.mkdir(blip_local_dir)
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "BLIP"),
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+ ext_filter=[".pth"],
+ download_name='model_base_caption_capfilt_large.pth',
+ )
- print("Downloading BLIP...")
- from requests import get as reqget
- open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
- print("BLIP downloaded to", blip_local_file + '.')
-
- blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model
|