aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/deepbooru.py26
-rw-r--r--modules/generation_parameters_copypaste.py9
-rw-r--r--modules/hypernetworks/hypernetwork.py8
-rw-r--r--modules/interrogate.py7
-rw-r--r--modules/processing.py4
-rw-r--r--modules/shared.py16
-rw-r--r--modules/textual_inversion/preprocess.py4
-rw-r--r--modules/ui.py54
8 files changed, 96 insertions, 32 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 419e6a9c..f34f3788 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -19,6 +19,7 @@ def get_deepbooru_tags(pil_image):
release_process()
+OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts():
from modules import shared
@@ -26,6 +27,7 @@ def create_deepbooru_opts():
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
}
@@ -113,6 +115,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
+ include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2]
height = model.input_shape[1]
@@ -151,19 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
if alpha_sort:
sort_ndx = 1
- # sort by reverse by likelihood and normal for alpha
+ # sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
- result_tags_out.append(tag)
+ # note: tag_outformat will still have a colon if include_ranks is True
+ tag_outformat = tag.replace(':', ' ')
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
- print('\n'.join(sorted(result_tags_print, reverse=True)))
-
- tags_text = ', '.join(result_tags_out)
+ result_tags_out.append(tag_outformat)
- if use_spaces:
- tags_text = tags_text.replace('_', ' ')
-
- if use_escape:
- tags_text = re.sub(re_special, r'\\\1', tags_text)
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
- return tags_text.replace(':', ' ')
+ return ', '.join(result_tags_out)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ac1ba7f4..c27826b6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,5 +1,8 @@
+import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b6c06d49..f1248bb7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
def __init__(self, dim, state_dict=None):
super().__init__()
@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
self.to(devices.device)
def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 635e266e..af858cc0 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if include_ranks:
+ res += ", " + match
+ else:
+ res += f", ({match}:{score})"
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/processing.py b/modules/processing.py
index 698b3069..d5172f00 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
seed = get_fixed_seed(p.seed)
diff --git a/modules/shared.py b/modules/shared.py
index 78b73aae..5901e605 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers
+from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -145,14 +145,14 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = None
- self.show_on_main_page = show_on_main_page
+ self.refresh = refresh
def options_section(section_identifier, options_dict):
@@ -237,8 +237,9 @@ options_templates.update(options_section(('training', "Training"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
- "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -250,14 +251,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
@@ -345,6 +349,8 @@ class Options:
item = self.data_labels.get(key)
item.onchange = func
+ func()
+
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3047bede..886cf0c3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -17,7 +17,9 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.interrogator.load()
if process_caption_deepbooru:
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
diff --git a/modules/ui.py b/modules/ui.py
index b18fe903..d8d886fa 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -79,6 +79,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -1218,8 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
-
- def create_setting_component(key):
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1239,7 +1240,34 @@ def create_ui(wrap_gradio_gpu_call):
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
+
+ def refresh():
+ info.refresh()
+ refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
+
+ for k, v in refreshed_args.items():
+ setattr(res, k, v)
+
+ return gr.update(**(refreshed_args or {}))
+
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[res],
+ )
+ else:
+ res = comp(label=info.label, value=fun, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
@@ -1313,6 +1341,9 @@ Requested path was: {f}
settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
quicksettings_list = []
cols_displayed = 0
@@ -1337,7 +1368,7 @@ Requested path was: {f}
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- if item.show_on_main_page:
+ if k in quicksettings_names:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1346,7 +1377,11 @@ Requested path was: {f}
components.append(component)
items_displayed += 1
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1354,10 +1389,6 @@ Requested path was: {f}
_js='function(){}'
)
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
-
def reload_scripts():
modules.scripts.reload_script_body_only()
@@ -1372,7 +1403,6 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
-
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1408,12 +1438,12 @@ Requested path was: {f}
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
- component = create_setting_component(k)
+ component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
- with gr.Tabs() as tabs:
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()