diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/deepbooru.py | 26 | ||||
-rw-r--r-- | modules/generation_parameters_copypaste.py | 9 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 8 | ||||
-rw-r--r-- | modules/interrogate.py | 7 | ||||
-rw-r--r-- | modules/processing.py | 4 | ||||
-rw-r--r-- | modules/shared.py | 16 | ||||
-rw-r--r-- | modules/textual_inversion/preprocess.py | 4 | ||||
-rw-r--r-- | modules/ui.py | 54 |
8 files changed, 96 insertions, 32 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 419e6a9c..f34f3788 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -19,6 +19,7 @@ def get_deepbooru_tags(pil_image): release_process() +OPT_INCLUDE_RANKS = "include_ranks" def create_deepbooru_opts(): from modules import shared @@ -26,6 +27,7 @@ def create_deepbooru_opts(): "use_spaces": shared.opts.deepbooru_use_spaces, "use_escape": shared.opts.deepbooru_escape, "alpha_sort": shared.opts.deepbooru_sort_alpha, + OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks, } @@ -113,6 +115,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o alpha_sort = deepbooru_opts['alpha_sort'] use_spaces = deepbooru_opts['use_spaces'] use_escape = deepbooru_opts['use_escape'] + include_ranks = deepbooru_opts['include_ranks'] width = model.input_shape[2] height = model.input_shape[1] @@ -151,19 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o if alpha_sort: sort_ndx = 1 - # sort by reverse by likelihood and normal for alpha + # sort by reverse by likelihood and normal for alpha, and format tag text as requested unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) for weight, tag in unsorted_tags_in_theshold: - result_tags_out.append(tag) + # note: tag_outformat will still have a colon if include_ranks is True + tag_outformat = tag.replace(':', ' ') + if use_spaces: + tag_outformat = tag_outformat.replace('_', ' ') + if use_escape: + tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) + if include_ranks: + tag_outformat = f"({tag_outformat}:{weight:.3f})" - print('\n'.join(sorted(result_tags_print, reverse=True))) - - tags_text = ', '.join(result_tags_out) + result_tags_out.append(tag_outformat) - if use_spaces: - tags_text = tags_text.replace('_', ' ') - - if use_escape: - tags_text = re.sub(re_special, r'\\\1', tags_text) + print('\n'.join(sorted(result_tags_print, reverse=True))) - return tags_text.replace(':', ' ') + return ', '.join(result_tags_out) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ac1ba7f4..c27826b6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,8 @@ +import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
def __init__(self, dim, state_dict=None):
super().__init__()
@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device)
def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork:
diff --git a/modules/interrogate.py b/modules/interrogate.py index 635e266e..af858cc0 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -123,7 +123,7 @@ class InterrogateModels: return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -156,7 +156,10 @@ class InterrogateModels: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if include_ranks:
+ res += ", " + match
+ else:
+ res += f", ({match}:{score})"
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/processing.py b/modules/processing.py index 698b3069..d5172f00 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else:
assert p.prompt is not None
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
seed = get_fixed_seed(p.seed)
diff --git a/modules/shared.py b/modules/shared.py index 78b73aae..5901e605 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers
+from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -145,14 +145,14 @@ def realesrgan_models_names(): class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = None
- self.show_on_main_page = show_on_main_page
+ self.refresh = refresh
def options_section(section_identifier, options_dict):
@@ -237,8 +237,9 @@ options_templates.update(options_section(('training', "Training"), { }))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
- "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -250,14 +251,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
@@ -345,6 +349,8 @@ class Options: item = self.data_labels.get(key)
item.onchange = func
+ func()
+
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3047bede..886cf0c3 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -17,7 +17,9 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.interrogator.load()
if process_caption_deepbooru:
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
diff --git a/modules/ui.py b/modules/ui.py index b18fe903..d8d886fa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -79,6 +79,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️ art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -1218,8 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[],
)
-
- def create_setting_component(key):
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1239,7 +1240,34 @@ def create_ui(wrap_gradio_gpu_call): else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
+
+ def refresh():
+ info.refresh()
+ refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
+
+ for k, v in refreshed_args.items():
+ setattr(res, k, v)
+
+ return gr.update(**(refreshed_args or {}))
+
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[res],
+ )
+ else:
+ res = comp(label=info.label, value=fun, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
@@ -1313,6 +1341,9 @@ Requested path was: {f} settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
quicksettings_list = []
cols_displayed = 0
@@ -1337,7 +1368,7 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- if item.show_on_main_page:
+ if k in quicksettings_names:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1346,7 +1377,11 @@ Requested path was: {f} components.append(component)
items_displayed += 1
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1354,10 +1389,6 @@ Requested path was: {f} _js='function(){}'
)
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
-
def reload_scripts():
modules.scripts.reload_script_body_only()
@@ -1372,7 +1403,6 @@ Requested path was: {f} shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
-
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1408,12 +1438,12 @@ Requested path was: {f} with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
- component = create_setting_component(k)
+ component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
- with gr.Tabs() as tabs:
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
|