diff options
author | MalumaDev <piano.lu92@gmail.com> | 2022-10-15 15:59:37 +0200 |
---|---|---|
committer | MalumaDev <piano.lu92@gmail.com> | 2022-10-15 15:59:37 +0200 |
commit | 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 (patch) | |
tree | bc23d469afc9f6ef1ecf9a1c15f7554e3d7ff5b5 /modules/aesthetic_clip.py | |
parent | bb57f30c2de46cfca5419ad01738a41705f96cc3 (diff) | |
download | stable-diffusion-webui-gfx803-37d7ffb415cd8c69b3c0bb5f61844dde0b169f78.tar.gz |
fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text
Diffstat (limited to 'modules/aesthetic_clip.py')
-rw-r--r-- | modules/aesthetic_clip.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py new file mode 100644 index 00000000..f15cfd47 --- /dev/null +++ b/modules/aesthetic_clip.py @@ -0,0 +1,78 @@ +import itertools +import os +from pathlib import Path +import html +import gc + +import gradio as gr +import torch +from PIL import Image +from modules import shared +from modules.shared import device, aesthetic_embeddings +from transformers import CLIPModel, CLIPProcessor + +from tqdm.auto import tqdm + + +def get_all_images_in_folder(folder): + return [os.path.join(folder, f) for f in os.listdir(folder) if + os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] + + +def check_is_valid_image_file(filename): + return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + + +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def iter_to_batched(iterable, n=1): + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +def generate_imgs_embd(name, folder, batch_size): + # clipModel = CLIPModel.from_pretrained( + # shared.sd_model.cond_stage_model.clipModel.name_or_path + # ) + model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) + processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + + with torch.no_grad(): + embs = [] + for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), + desc=f"Generating embeddings for {name}"): + if shared.state.interrupted: + break + inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) + outputs = model.get_image_features(**inputs).cpu() + embs.append(torch.clone(outputs)) + inputs.to("cpu") + del inputs, outputs + + embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) + + # The generated embedding will be located here + path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") + torch.save(embs, path) + + model = model.cpu() + del model + del processor + del embs + gc.collect() + torch.cuda.empty_cache() + res = f""" + Done generating embedding for {name}! + Hypernetwork saved to {html.escape(path)} + """ + shared.update_aesthetic_embeddings() + return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(aesthetic_embeddings.keys())[0] if len( + aesthetic_embeddings) > 0 else None), res, "" |