aboutsummaryrefslogtreecommitdiffstats
path: root/modules/aesthetic_clip.py
diff options
context:
space:
mode:
authorMalumaDev <piano.lu92@gmail.com>2022-10-15 15:59:37 +0200
committerMalumaDev <piano.lu92@gmail.com>2022-10-15 15:59:37 +0200
commit37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 (patch)
treebc23d469afc9f6ef1ecf9a1c15f7554e3d7ff5b5 /modules/aesthetic_clip.py
parentbb57f30c2de46cfca5419ad01738a41705f96cc3 (diff)
downloadstable-diffusion-webui-gfx803-37d7ffb415cd8c69b3c0bb5f61844dde0b169f78.tar.gz
fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text
Diffstat (limited to 'modules/aesthetic_clip.py')
-rw-r--r--modules/aesthetic_clip.py78
1 files changed, 78 insertions, 0 deletions
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
new file mode 100644
index 00000000..f15cfd47
--- /dev/null
+++ b/modules/aesthetic_clip.py
@@ -0,0 +1,78 @@
+import itertools
+import os
+from pathlib import Path
+import html
+import gc
+
+import gradio as gr
+import torch
+from PIL import Image
+from modules import shared
+from modules.shared import device, aesthetic_embeddings
+from transformers import CLIPModel, CLIPProcessor
+
+from tqdm.auto import tqdm
+
+
+def get_all_images_in_folder(folder):
+ return [os.path.join(folder, f) for f in os.listdir(folder) if
+ os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
+
+
+def check_is_valid_image_file(filename):
+ return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
+
+
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def iter_to_batched(iterable, n=1):
+ it = iter(iterable)
+ while True:
+ chunk = tuple(itertools.islice(it, n))
+ if not chunk:
+ return
+ yield chunk
+
+
+def generate_imgs_embd(name, folder, batch_size):
+ # clipModel = CLIPModel.from_pretrained(
+ # shared.sd_model.cond_stage_model.clipModel.name_or_path
+ # )
+ model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
+ processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
+
+ with torch.no_grad():
+ embs = []
+ for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
+ desc=f"Generating embeddings for {name}"):
+ if shared.state.interrupted:
+ break
+ inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
+ outputs = model.get_image_features(**inputs).cpu()
+ embs.append(torch.clone(outputs))
+ inputs.to("cpu")
+ del inputs, outputs
+
+ embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+ # The generated embedding will be located here
+ path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
+ torch.save(embs, path)
+
+ model = model.cpu()
+ del model
+ del processor
+ del embs
+ gc.collect()
+ torch.cuda.empty_cache()
+ res = f"""
+ Done generating embedding for {name}!
+ Hypernetwork saved to {html.escape(path)}
+ """
+ shared.update_aesthetic_embeddings()
+ return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
+ value=sorted(aesthetic_embeddings.keys())[0] if len(
+ aesthetic_embeddings) > 0 else None), res, ""