diff options
-rw-r--r-- | extensions-builtin/Lora/network_oft.py | 37 | ||||
-rw-r--r-- | extensions-builtin/extra-options-section/scripts/extra_options_section.py | 5 | ||||
-rw-r--r-- | javascript/ui.js | 24 | ||||
-rw-r--r-- | modules/shared_options.py | 1 | ||||
-rw-r--r-- | modules/styles.py | 54 | ||||
-rw-r--r-- | modules/xpu_specific.py | 9 |
6 files changed, 58 insertions, 72 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c37811..fa647020 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - return updown, ex_bias diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df62..ac2c3de4 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = []
self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
+ elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
+ with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """),
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
- "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
+ "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
}))
diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e..18c9f891 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d01..acb6e2d4 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,6 +256,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
}))
diff --git a/modules/styles.py b/modules/styles.py index 4d218cd7..81d9800d 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -2,7 +2,6 @@ import csv import fnmatch
import os
import os.path
-import re
import typing
import shutil
@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple): path: str = None
-def clean_text(text: str) -> str:
- """
- Iterating through a list of regular expressions and replacement strings, we
- clean up the prompt and style text to make it easier to match against each
- other.
- """
- re_list = [
- ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
- ("multiple spaces", re.compile("\s{2,}"), " "),
- ]
- for _, regex, replace in re_list:
- text = regex.sub(replace, text)
-
- return text.strip(", ")
-
-
def merge_prompts(style_prompt: str, prompt: str) -> str:
if "{prompt}" in style_prompt:
res = style_prompt.replace("{prompt}", prompt)
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles): for style in styles:
prompt = merge_prompts(style, prompt)
- return clean_text(prompt)
+ return prompt
def unwrap_style_text_from_prompt(style_text, prompt):
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt): Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
- stripped_prompt = clean_text(prompt)
- stripped_style_text = clean_text(style_text)
+ stripped_prompt = prompt
+ stripped_style_text = style_text
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style.
@@ -115,10 +98,8 @@ class StyleDatabase: self.path = path
folder, file = os.path.split(self.path)
- self.default_file = file.split("*")[0] + ".csv"
- if self.default_file == ".csv":
- self.default_file = "styles.csv"
- self.default_path = os.path.join(folder, self.default_file)
+ filename, _, ext = file.partition('*')
+ self.default_path = os.path.join(folder, filename + ext)
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -172,10 +153,8 @@ class StyleDatabase: row["name"], prompt, negative_prompt, path
)
- def get_style_paths(self) -> list():
- """
- Returns a list of all distinct paths, including the default path, of
- files that styles are loaded from."""
+ def get_style_paths(self) -> set:
+ """Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
@@ -189,9 +168,9 @@ class StyleDatabase: style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths.discard("do_not_save")
- return list(style_paths)
+ return style_paths
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -213,20 +192,7 @@ class StyleDatabase: # The path argument is deprecated, but kept for backwards compatibility
_ = path
- # Update any styles without a path to the default path
- for style in list(self.styles.values()):
- if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
-
- # Create a list of all distinct paths, including the default path
- style_paths = set()
- style_paths.add(self.default_path)
- for _, style in self.styles.items():
- if style.path:
- style_paths.add(style.path)
-
- # Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths = self.get_style_paths()
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790..d8da94a0 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,12 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) |