diff options
author | Greendayle <Greendayle> | 2022-10-05 18:50:10 +0000 |
---|---|---|
committer | Greendayle <Greendayle> | 2022-10-05 18:55:26 +0000 |
commit | 59a2b9e5afc27d2fda72069ca0635070535d18fe (patch) | |
tree | f63afb4de3427072e81212131488f42be0fcbef4 /modules/deepbooru.py | |
parent | 1eb588cbf19924333b88beaa1ac0041904966640 (diff) | |
download | stable-diffusion-webui-gfx803-59a2b9e5afc27d2fda72069ca0635070535d18fe.tar.gz stable-diffusion-webui-gfx803-59a2b9e5afc27d2fda72069ca0635070535d18fe.tar.bz2 stable-diffusion-webui-gfx803-59a2b9e5afc27d2fda72069ca0635070535d18fe.zip |
deepdanbooru interrogator
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r-- | modules/deepbooru.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py new file mode 100644 index 00000000..958b1c3d --- /dev/null +++ b/modules/deepbooru.py @@ -0,0 +1,60 @@ +import os.path +from concurrent.futures import ProcessPoolExecutor + +import numpy as np +import deepdanbooru as dd +import tensorflow as tf + + +def _load_tf_and_return_tags(pil_image, threshold): + this_folder = os.path.dirname(__file__) + model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') + if not os.path.exists(model_path): + return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru" + + tags = dd.project.load_tags_from_project(model_path) + model = dd.project.load_model_from_project( + model_path, compile_model=True + ) + + width = model.input_shape[2] + height = model.input_shape[1] + image = np.array(pil_image) + image = tf.image.resize( + image, + size=(height, width), + method=tf.image.ResizeMethod.AREA, + preserve_aspect_ratio=True, + ) + image = image.numpy() # EagerTensor to np.array + image = dd.image.transform_and_pad_image(image, width, height) + image = image / 255.0 + image_shape = image.shape + image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) + + y = model.predict(image)[0] + + result_dict = {} + + for i, tag in enumerate(tags): + result_dict[tag] = y[i] + + + + result_tags_out = [] + result_tags_print = [] + for tag in tags: + if result_dict[tag] >= threshold: + result_tags_out.append(tag) + result_tags_print.append(f'{result_dict[tag]} {tag}') + + print('\n'.join(sorted(result_tags_print, reverse=True))) + + return ', '.join(result_tags_out) + + +def get_deepbooru_tags(pil_image, threshold=0.5): + with ProcessPoolExecutor() as executor: + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold) + ret = f.result() # will rethrow any exceptions + return ret
\ No newline at end of file |