diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-10-12 05:35:27 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-12 05:35:27 +0000 |
commit | 2e2d45b281cf3bafbf5034417dab34582a56787b (patch) | |
tree | 691b7aa7e02899f57b51c8818d4ebc45c41db7e6 /modules/deepbooru.py | |
parent | fec2221eeaafb50afd26ba3e109bf6f928011e69 (diff) | |
parent | f53f703aebc801c4204182d52bb1e0bef9808e1f (diff) | |
download | stable-diffusion-webui-gfx803-2e2d45b281cf3bafbf5034417dab34582a56787b.tar.gz stable-diffusion-webui-gfx803-2e2d45b281cf3bafbf5034417dab34582a56787b.tar.bz2 stable-diffusion-webui-gfx803-2e2d45b281cf3bafbf5034417dab34582a56787b.zip |
Merge pull request #2143 from JC-Array/deepdanbooru_pre_process
deepbooru tags for textual inversion preproccessing
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r-- | modules/deepbooru.py | 99 |
1 files changed, 79 insertions, 20 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 7e3c0618..29529949 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,21 +1,75 @@ import os.path from concurrent.futures import ProcessPoolExecutor -from multiprocessing import get_context +import multiprocessing +import time +def get_deepbooru_tags(pil_image): + """ + This method is for running only one image at a time for simple use. Used to the img2img interrogate. + """ + from modules import shared # prevents circular reference + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha) + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(pil_image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + tags = shared.deepbooru_process_return["value"] + release_process() + return tags -def _load_tf_and_return_tags(pil_image, threshold): + +def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): + model, tags = get_deepbooru_tags_model() + while True: # while process is running, keep monitoring queue for new image + pil_image = queue.get() + if pil_image == "QUIT": + break + else: + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) + + +def create_deepbooru_process(threshold, alpha_sort): + """ + Creates deepbooru process. A queue is created to send images into the process. This enables multiple images + to be processed in a row without reloading the model or creating a new process. To return the data, a shared + dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned + to the dictionary and the method adding the image to the queue should wait for this value to be updated with + the tags. + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_manager = multiprocessing.Manager() + shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() + shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) + shared.deepbooru_process.start() + + +def release_process(): + """ + Stops the deepbooru process to return used memory + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_queue.put("QUIT") + shared.deepbooru_process.join() + shared.deepbooru_process_queue = None + shared.deepbooru_process = None + shared.deepbooru_process_return = None + shared.deepbooru_process_manager = None + +def get_deepbooru_tags_model(): import deepdanbooru as dd import tensorflow as tf import numpy as np - this_folder = os.path.dirname(__file__) model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) if not os.path.exists(os.path.join(model_path, 'project.json')): # there is no point importing these every time import zipfile from basicsr.utils.download_util import load_file_from_url - load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", - model_path) + load_file_from_url( + r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: zip_ref.extractall(model_path) os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) @@ -24,7 +78,13 @@ def _load_tf_and_return_tags(pil_image, threshold): model = dd.project.load_model_from_project( model_path, compile_model=True ) + return model, tags + +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): + import deepdanbooru as dd + import tensorflow as tf + import numpy as np width = model.input_shape[2] height = model.input_shape[1] image = np.array(pil_image) @@ -46,28 +106,27 @@ def _load_tf_and_return_tags(pil_image, threshold): for i, tag in enumerate(tags): result_dict[tag] = y[i] - result_tags_out = [] + + unsorted_tags_in_theshold = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + unsorted_tags_in_theshold.append((result_dict[tag], tag)) result_tags_print.append(f'{result_dict[tag]} {tag}') - print('\n'.join(sorted(result_tags_print, reverse=True))) - - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') - + # sort tags + result_tags_out = [] + sort_ndx = 0 + if alpha_sort: + sort_ndx = 1 -def subprocess_init_no_cuda(): - import os - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + # sort by reverse by likelihood and normal for alpha + unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) + for weight, tag in unsorted_tags_in_theshold: + result_tags_out.append(tag) + print('\n'.join(sorted(result_tags_print, reverse=True))) -def get_deepbooru_tags(pil_image, threshold=0.5): - context = get_context('spawn') - with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) - ret = f.result() # will rethrow any exceptions - return ret
\ No newline at end of file + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') |