author | NoCrypt <57245077+NoCrypt@users.noreply.github.com> | 2022-11-11 14:14:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-11 14:14:10 +0000 |
commit | 6165f07e74f05543bf9039dda5d66686d18d985a (patch) | |
tree | 8b5d06daf8ed9626e2ac7f872720b87e1207bced | |
parent | c556d34523e8764bd66bf6a7bf97d06add420020 (diff) | |
parent | e666220ee458ae1e80a2ba12c64a0da9d68f20a2 (diff) | |
download | stable-diffusion-webui-gfx803-6165f07e74f05543bf9039dda5d66686d18d985a.tar.gz stable-diffusion-webui-gfx803-6165f07e74f05543bf9039dda5d66686d18d985a.tar.bz2 stable-diffusion-webui-gfx803-6165f07e74f05543bf9039dda5d66686d18d985a.zip |
Merge branch 'master' into patch-1
-rw-r--r-- | javascript/generationParams.js | 33
-rw-r--r-- | modules/api/api.py | 16
-rw-r--r-- | modules/api/models.py | 1
-rw-r--r-- | modules/ngrok.py | 13
-rw-r--r-- | modules/sd_models.py | 29
-rw-r--r-- | modules/shared.py | 2
-rw-r--r-- | modules/textual_inversion/dataset.py | 7
-rw-r--r-- | modules/ui.py | 22
-rw-r--r-- | scripts/prompt_matrix.py | 2
-rw-r--r-- | scripts/prompts_from_file.py | 6
10 files changed, 114 insertions, 17 deletions
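For context on the modules/api changes below (a new "model" field on the interrogate request, with "deepdanbooru" routed to the tagger when the webui was launched with --deepdanbooru), here is a minimal client-side sketch. It is not part of this commit; the local URL and port are assumptions for a default webui launch.

# Hypothetical client sketch (not in this commit): POST to the interrogate
# endpoint with the new "model" field. Assumes a default local webui at port 7860.
import base64
import requests

with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {"image": image_b64, "model": "deepdanbooru"}  # "clip" is the default
r = requests.post("http://127.0.0.1:7860/sdapi/v1/interrogate", json=payload)
print(r.json()["caption"])  # InterrogateResponse.caption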
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
new file mode 100644
index 00000000..95f05093
--- /dev/null
+++ b/javascript/generationParams.js
@@ -0,0 +1,33 @@
+// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
+
+let txt2img_gallery, img2img_gallery, modal = undefined;
+onUiUpdate(function(){
+    if (!txt2img_gallery) {
+        txt2img_gallery = attachGalleryListeners("txt2img")
+    }
+    if (!img2img_gallery) {
+        img2img_gallery = attachGalleryListeners("img2img")
+    }
+    if (!modal) {
+        modal = gradioApp().getElementById('lightboxModal')
+        modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
+    }
+});
+
+let modalObserver = new MutationObserver(function(mutations) {
+    mutations.forEach(function(mutationRecord) {
+        let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
+        if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
+            gradioApp().getElementById(selectedTab+"_generation_info_button").click()
+    });
+});
+
+function attachGalleryListeners(tab_name) {
+    gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
+    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
+    gallery?.addEventListener('keydown', (e) => {
+        if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
+            gradioApp().getElementById(tab_name+"_generation_info_button").click()
+    });
+    return gallery;
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index 688469ad..596a6616 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list
 from modules.realesrgan_model import get_realesrgan_models
 from typing import List
 
+if shared.cmd_opts.deepdanbooru:
+    from modules.deepbooru import get_deepbooru_tags
+
 def upscaler_to_index(name: str):
     try:
         return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -220,11 +223,20 @@ class Api:
         if image_b64 is None:
             raise HTTPException(status_code=404, detail="Image not found")
 
-        img = self.__base64_to_image(image_b64)
+        img = decode_base64_to_image(image_b64)
+        img = img.convert('RGB')
 
         # Override object param
         with self.queue_lock:
-            processed = shared.interrogator.interrogate(img)
+            if interrogatereq.model == "clip":
+                processed = shared.interrogator.interrogate(img)
+            elif interrogatereq.model == "deepdanbooru":
+                if shared.cmd_opts.deepdanbooru:
+                    processed = get_deepbooru_tags(img)
+                else:
+                    raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.")
+            else:
+                raise HTTPException(status_code=404, detail="Model not found")
 
         return InterrogateResponse(caption=processed)
diff --git a/modules/api/models.py b/modules/api/models.py
index 34dbfa16..f9cd929e 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -170,6 +170,7 @@ class ProgressResponse(BaseModel):
 class InterrogateRequest(BaseModel):
     image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+    model: str = Field(default="clip", title="Model", description="The interrogate model used.")
 
 class InterrogateResponse(BaseModel):
     caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 25c53af8..64c9a3c2 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,14 +1,23 @@
 from pyngrok import ngrok, conf, exception
 
-
 def connect(token, port, region):
+    account = None
     if token == None:
         token = 'None'
+    else:
+        if ':' in token:
+            # token = authtoken:username:password
+            account = token.split(':')[1] + ':' + token.split(':')[-1]
+            token = token.split(':')[0]
+
     config = conf.PyngrokConfig(
         auth_token=token, region=region
     )
     try:
-        public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+        if account == None:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+        else:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
     except exception.PyngrokNgrokError:
         print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
               f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 34c57bfa..80addf03 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+    if cache_enabled:
         sd_vae.restore_base_vae(model)
-        checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
 
     vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
 
-    if checkpoint_info not in checkpoints_loaded:
+    if cache_enabled and checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+        vae_message = f" with {vae_name} VAE" if vae_name else ""
+        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+    else:
+        # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         del pl_sd
         model.load_state_dict(sd, strict=False)
         del sd
+
+        if cache_enabled:
+            # cache newly loaded model
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
 
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
@@ -199,14 +211,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
             model.first_stage_model.to(devices.dtype_vae)
 
-    else:
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-
-    if shared.opts.sd_checkpoint_cache > 0:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+    # clean up cache if limit is reached
+    if cache_enabled:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
             checkpoints_loaded.popitem(last=False) # LRU
 
     model.sd_model_hash = sd_model_hash
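The sd_models.py hunks above rework the checkpoint cache: checkpoints_loaded acts as an LRU store, and the limit check allows sd_checkpoint_cache entries in addition to the model currently in use, hence the "+ 1". A minimal sketch of that policy (not the webui implementation; names stand in for the shared options):

# Minimal sketch of the cache policy described above.
import collections

checkpoints_loaded = collections.OrderedDict()
sd_checkpoint_cache = 2  # stands in for shared.opts.sd_checkpoint_cache

def cache_state_dict(checkpoint_info, state_dict):
    checkpoints_loaded[checkpoint_info] = state_dict
    # keep sd_checkpoint_cache cached entries plus the model currently loaded
    while len(checkpoints_loaded) > sd_checkpoint_cache + 1:
        checkpoints_loaded.popitem(last=False)  # evict the oldest entry (LRU)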
diff --git a/modules/shared.py b/modules/shared.py
index e8bacd3c..caabf078 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -319,6 +319,8 @@ options_templates.update(options_section(('system', "System"), {
 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+    "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
+    "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
     "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index ad726577..eb75c376 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
     def create_text(self, filename_text):
         text = random.choice(self.lines)
         text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", filename_text)
+        tags = filename_text.split(',')
+        if shared.opts.tag_drop_out != 0:
+            tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
+        if shared.opts.shuffle_tags:
+            random.shuffle(tags)
+        text = text.replace("[filewords]", ','.join(tags))
         return text
 
     def __len__(self):
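The shared.py and dataset.py changes above add tag shuffling and tag dropout when training captions are built from "[filewords]": each comma-separated tag survives with probability (1 - tag_drop_out), and the survivors are optionally shuffled before being joined back. A self-contained illustration of the same transformation; the caption, template, and option values are assumptions, not repository code:

# Standalone illustration of the create_text() change; values are assumptions.
import random

filename_text = "1girl, blue hair, smile, outdoors"  # example caption from a filename
tag_drop_out = 0.1   # stands in for shared.opts.tag_drop_out
shuffle_tags = True  # stands in for shared.opts.shuffle_tags

tags = filename_text.split(',')
if tag_drop_out != 0:
    tags = [t for t in tags if random.random() > tag_drop_out]  # drop each tag with probability tag_drop_out
if shuffle_tags:
    random.shuffle(tags)

template = "a photo of [name], [filewords]"  # assumed template line
text = template.replace("[name]", "my-embedding").replace("[filewords]", ','.join(tags))
print(text)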
diff --git a/modules/ui.py b/modules/ui.py
index 7ea1177f..5dce7f3b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -566,6 +566,19 @@ def apply_setting(key, value):
     return value
 
 
+def update_generation_info(args):
+    generation_info, html_info, img_index = args
+    try:
+        generation_info = json.loads(generation_info)
+        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+            return html_info
+        return plaintext_to_html(generation_info["infotexts"][img_index])
+    except Exception:
+        pass
+    # if the json parse or anything else fails, just return the old html_info
+    return html_info
+
+
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
     def refresh():
         refresh_method()
@@ -638,6 +651,15 @@ Requested path was: {f}
                 with gr.Group():
                     html_info = gr.HTML()
                     generation_info = gr.Textbox(visible=False)
+                    if tabname == 'txt2img' or tabname == 'img2img':
+                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+                        generation_info_button.click(
+                            fn=update_generation_info,
+                            _js="(x, y) => [x, y, selected_gallery_index()]",
+                            inputs=[generation_info, html_info],
+                            outputs=[html_info],
+                            preprocess=False
+                        )
 
                     save.click(
                         fn=wrap_gradio_call(save_files),
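The ui.py hunks above add a hidden per-tab button that the JavaScript in javascript/generationParams.js clicks whenever the displayed gallery image changes; update_generation_info then re-renders the infotext for the selected image. A hedged sketch of the data it consumes, with invented values:

# Invented example payload: the hidden generation_info textbox holds JSON with
# one infotext per image; the selected gallery index picks the entry to show.
import json

generation_info = json.dumps({
    "infotexts": [
        "a cat\nSteps: 20, Sampler: Euler a, Seed: 1",
        "a cat\nSteps: 20, Sampler: Euler a, Seed: 2",
    ]
})

img_index = 1  # value supplied by selected_gallery_index() on the JS side
info = json.loads(generation_info)
print(info["infotexts"][img_index])  # parameters for the image currently shown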
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index e49c9b20..4d1e152d 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -80,6 +80,8 @@ class Script(scripts.Script):
         grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
         grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
         processed.images.insert(0, grid)
+        processed.index_of_first_image = 1
+        processed.infotexts.insert(0, processed.infotexts[0])
 
         if opts.grid_save:
             images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)
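The prompt_matrix.py change above keeps processed.images and processed.infotexts index-aligned after the grid is inserted at position 0, and index_of_first_image = 1 marks the grid so per-image handling starts at the second entry. A toy illustration of that invariant (names and values are illustrative only, not webui code):

# Toy check of the alignment the change preserves.
images = ["grid", "image_a", "image_b"]
infotexts = ["params_a", "params_a", "params_b"]  # infotexts[0] duplicated for the grid
index_of_first_image = 1  # entries from here on are individual images

assert len(images) == len(infotexts)
for img, info in zip(images[index_of_first_image:], infotexts[index_of_first_image:]):
    print(img, "->", info)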
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 3388bc77..32fe6bdb 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -145,6 +145,8 @@ class Script(scripts.Script):
         state.job_count = job_count
 
         images = []
+        all_prompts = []
+        infotexts = []
 
         for n, args in enumerate(jobs):
             state.job = f"{state.job_no + 1} out of {state.job_count}"
@@ -157,5 +159,7 @@ class Script(scripts.Script):
             if checkbox_iterate:
                 p.seed = p.seed + (p.batch_size * p.n_iter)
 
+            all_prompts += proc.all_prompts
+            infotexts += proc.infotexts
 
-        return Processed(p, images, p.seed, "")
+        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
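The prompts_from_file.py change above accumulates per-job prompts and infotexts so the returned Processed object carries metadata for every generated image instead of an empty info string. A rough sketch of the accumulation pattern, with a hypothetical per_job_results iterable standing in for the script's job loop:

# Rough sketch; per_job_results is a hypothetical stand-in for the loop over jobs.
images, all_prompts, infotexts = [], [], []
for proc in per_job_results:  # each proc is a Processed returned by process_images(p)
    images += proc.images
    all_prompts += proc.all_prompts
    infotexts += proc.infotexts
# the final result then mirrors:
# Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)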