diff options
author | Greg Fuller <gfuller23@gmail.com> | 2022-10-12 19:44:41 +0000 |
---|---|---|
committer | Greg Fuller <gfuller23@gmail.com> | 2022-10-12 19:44:41 +0000 |
commit | fb3cefb348600497d964c4fd3f99138d7dde0ec5 (patch) | |
tree | 984c02da13dae6ab0d6c21304a551bbfffea02f2 /modules/deepbooru.py | |
parent | d717eb079cd6b7fa7a4f97c0a10d400bdec753fb (diff) | |
parent | 698d303b04e293635bfb49c525409f3bcf671dce (diff) | |
download | stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.tar.gz stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.tar.bz2 stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.zip |
Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r-- | modules/deepbooru.py | 136 |
1 files changed, 114 insertions, 22 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 32d741e2..419e6a9c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,21 +1,99 @@ import os.path from concurrent.futures import ProcessPoolExecutor -from multiprocessing import get_context +import multiprocessing +import time +import re +re_special = re.compile(r'([\\()])') -def _load_tf_and_return_tags(pil_image, threshold, include_ranks): +def get_deepbooru_tags(pil_image): + """ + This method is for running only one image at a time for simple use. Used to the img2img interrogate. + """ + from modules import shared # prevents circular reference + + try: + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) + return get_tags_from_process(pil_image) + finally: + release_process() + + +def create_deepbooru_opts(): + from modules import shared + + return { + "use_spaces": shared.opts.deepbooru_use_spaces, + "use_escape": shared.opts.deepbooru_escape, + "alpha_sort": shared.opts.deepbooru_sort_alpha, + } + + +def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): + model, tags = get_deepbooru_tags_model() + while True: # while process is running, keep monitoring queue for new image + pil_image = queue.get() + if pil_image == "QUIT": + break + else: + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) + + +def create_deepbooru_process(threshold, deepbooru_opts): + """ + Creates deepbooru process. A queue is created to send images into the process. This enables multiple images + to be processed in a row without reloading the model or creating a new process. To return the data, a shared + dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned + to the dictionary and the method adding the image to the queue should wait for this value to be updated with + the tags. + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_manager = multiprocessing.Manager() + shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() + shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) + shared.deepbooru_process.start() + + +def get_tags_from_process(image): + from modules import shared + + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + caption = shared.deepbooru_process_return["value"] + shared.deepbooru_process_return["value"] = -1 + + return caption + + +def release_process(): + """ + Stops the deepbooru process to return used memory + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_queue.put("QUIT") + shared.deepbooru_process.join() + shared.deepbooru_process_queue = None + shared.deepbooru_process = None + shared.deepbooru_process_return = None + shared.deepbooru_process_manager = None + +def get_deepbooru_tags_model(): import deepdanbooru as dd import tensorflow as tf import numpy as np - this_folder = os.path.dirname(__file__) model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) if not os.path.exists(os.path.join(model_path, 'project.json')): # there is no point importing these every time import zipfile from basicsr.utils.download_util import load_file_from_url - load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", - model_path) + load_file_from_url( + r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: zip_ref.extractall(model_path) os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) @@ -24,6 +102,17 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks): model = dd.project.load_model_from_project( model_path, compile_model=True ) + return model, tags + + +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): + import deepdanbooru as dd + import tensorflow as tf + import numpy as np + + alpha_sort = deepbooru_opts['alpha_sort'] + use_spaces = deepbooru_opts['use_spaces'] + use_escape = deepbooru_opts['use_escape'] width = model.input_shape[2] height = model.input_shape[1] @@ -46,32 +135,35 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks): for i, tag in enumerate(tags): result_dict[tag] = y[i] - result_tags_out = [] + + unsorted_tags_in_theshold = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - tag_formatted = tag.replace('_', ' ').replace(':', ' ') - if include_ranks: - result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})') - else: - result_tags_out.append(tag_formatted) + unsorted_tags_in_theshold.append((result_dict[tag], tag)) result_tags_print.append(f'{result_dict[tag]} {tag}') - print('\n'.join(sorted(result_tags_print, reverse=True))) + # sort tags + result_tags_out = [] + sort_ndx = 0 + if alpha_sort: + sort_ndx = 1 + + # sort by reverse by likelihood and normal for alpha + unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) + for weight, tag in unsorted_tags_in_theshold: + result_tags_out.append(tag) - return ', '.join(result_tags_out) + print('\n'.join(sorted(result_tags_print, reverse=True))) + tags_text = ', '.join(result_tags_out) -def subprocess_init_no_cuda(): - import os - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + if use_spaces: + tags_text = tags_text.replace('_', ' ') + if use_escape: + tags_text = re.sub(re_special, r'\\\1', tags_text) -def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False): - context = get_context('spawn') - with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks) - ret = f.result() # will rethrow any exceptions - return ret
\ No newline at end of file + return tags_text.replace(':', ' ') |