From 2abd89acc66419abf2eee9b03fd093f2737670de Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 20:04:35 +0300 Subject: index on master: 91c8d0d Merge pull request #7231 from EllangoK/master --- extensions-builtin/Lora/lora.py | 21 +++++++++++++++++++-- extensions-builtin/Lora/scripts/lora_script.py | 5 +++++ 2 files changed, 24 insertions(+), 2 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36..568a7675 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -12,7 +12,7 @@ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+) re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") -def convert_diffusers_name_to_compvis(key): +def convert_diffusers_name_to_compvis(key, is_sd2): def match(match_list, regex): r = re.match(regex, key) if not r: @@ -34,6 +34,14 @@ def convert_diffusers_name_to_compvis(key): return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" if match(m, re_text_block): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + elif 'self_attn': + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" return key @@ -83,9 +91,10 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) keys_failed_to_match = [] + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers) + fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) key, lora_key = fullkey.split(".", 1) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) @@ -104,9 +113,13 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: + print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + continue assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' with torch.no_grad(): @@ -182,6 +195,10 @@ def lora_Conv2d_forward(self, input): return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) +def lora_NonDynamicallyQuantizableLinear_forward(self, input): + return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input)) + + def list_available_loras(): available_loras.clear() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160..a385ae94 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -10,6 +10,7 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Conv2d.forward = 
torch.nn.Conv2d_forward_before_lora + torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora def before_ui(): @@ -23,8 +24,12 @@ if not hasattr(torch.nn, 'Linear_forward_before_lora'): if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'NonDynamicallyQuantizableLinear_forward_before_lora'): + torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora = torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward + torch.nn.Linear.forward = lora.lora_Linear_forward torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = lora.lora_NonDynamicallyQuantizableLinear_forward script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) -- cgit v1.2.3 From 04924241218bb51bee255bebc6c66ef1de449f4a Mon Sep 17 00:00:00 2001 From: bluelovers Date: Sun, 12 Mar 2023 10:18:33 +0800 Subject: feat: try sort as ignore-case https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8368 --- extensions-builtin/Lora/lora.py | 2 +- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36..7d3c0f90 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -192,7 +192,7 @@ def list_available_loras(): glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) - for filename in sorted(candidates): + for filename in sorted(candidates, key=str.lower): if os.path.isdir(filename): continue diff --git a/modules/shared.py b/modules/shared.py index 805f9cc1..1322b96d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -702,7 +702,7 @@ mem_mon.start() def listfiles(dirname): - filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c63c7d1d..3d21b9fe 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -129,7 +129,7 @@ class EmbeddingDatabase: if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] - self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) + self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True, cmp=lambda x, y: x.lower() > y.lower()) return embedding @@ -196,7 +196,7 @@ class EmbeddingDatabase: return for root, dirs, fns in os.walk(embdir.path, followlinks=True): - for fn in fns: + for fn in sorted(fns, key=str.lower): try: fullfn = os.path.join(root, fn) -- cgit v1.2.3 From 2f0181405f25e1448a55697081e380020fe8c68d Mon Sep 17 00:00:00 2001 From: FNSpd <125805478+FNSpd@users.noreply.github.com> Date: Tue, 21 Mar 2023 14:53:51 +0400 Subject: Update lora.py --- extensions-builtin/Lora/lora.py | 1 + 1 file changed, 1 insertion(+) (limited to 
'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 8937b585..7c371deb 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -178,6 +178,7 @@ def load_loras(names, multipliers=None): def lora_forward(module, input, res): + input = devices.cond_cast_unet(input) if len(loaded_loras) == 0: return res -- cgit v1.2.3 From 9f0da9f6edfb9be1d69ba3492a61d96db769307b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 20 Mar 2023 16:09:36 +0300 Subject: initial gradio 3.22 support --- .../javascript/prompt-bracket-checker.js | 15 +- javascript/hints.js | 2 +- javascript/imageviewer.js | 70 +-- javascript/progressbar.js | 67 +-- javascript/ui.js | 3 +- modules/scripts.py | 3 + modules/scripts_postprocessing.py | 2 +- modules/ui.py | 43 +- modules/ui_common.py | 3 +- modules/ui_components.py | 36 +- requirements.txt | 2 +- requirements_versions.txt | 2 +- script.js | 6 +- scripts/postprocessing_upscale.py | 34 +- style.css | 626 ++++++--------------- 15 files changed, 289 insertions(+), 625 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 4a85c8eb..f0918e26 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -89,22 +89,15 @@ function checkBrackets(evt, textArea, counterElt) { function setupBracketChecking(id_prompt, id_counter){ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); var counter = gradioApp().getElementById(id_counter) + textarea.addEventListener("input", function(evt){ checkBrackets(evt, textarea, counter) }); } -var shadowRootLoaded = setInterval(function() { - var shadowRoot = document.querySelector('gradio-app').shadowRoot; - if(! 
shadowRoot) return false; - - var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); - if(shadowTextArea.length < 1) return false; - - clearInterval(shadowRootLoaded); - +onUiLoaded(function(){ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') - setupBracketChecking('img2img_prompt', 'imgimg_token_counter') + setupBracketChecking('img2img_prompt', 'img2img_token_counter') setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') -}, 1000); +}) \ No newline at end of file diff --git a/javascript/hints.js b/javascript/hints.js index 7f4101b2..61763e6b 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -18,7 +18,7 @@ titles = { "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "\u{1f4be}": "Save style", - "\u{1f5d1}": "Clear prompt", + "\u{1f5d1}\ufe0f": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", "\u{1f3b4}": "Show extra networks", diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 28e748b7..7547e771 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -50,7 +50,7 @@ function updateOnBackgroundChange() { } function modalImageSwitch(offset) { - var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all") + var allgalleryButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item") var galleryButtons = [] allgalleryButtons.forEach(function(elem) { if (elem.parentElement.offsetParent) { @@ -59,7 +59,7 @@ function modalImageSwitch(offset) { }) if (galleryButtons.length > 1) { - var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") + var allcurrentButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item.selected") var currentButton = null allcurrentButtons.forEach(function(elem) { if (elem.parentElement.offsetParent) { @@ -136,37 +136,29 @@ function modalKeyHandler(event) { } } -function showGalleryImage() { - setTimeout(function() { - fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain') - - if (fullImg_preview != null) { - fullImg_preview.forEach(function function_name(e) { - if (e.dataset.modded) - return; - e.dataset.modded = true; - if(e && e.parentElement.tagName == 'DIV'){ - e.style.cursor='pointer' - e.style.userSelect='none' - - var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 - - // For Firefox, listening on click first switched to next image then shows the lightbox. - // If you know how to fix this without switching to mousedown event, please. - // For other browsers the event is click to make it possiblr to drag picture. - var event = isFirefox ? 'mousedown' : 'click' - - e.addEventListener(event, function (evt) { - if(!opts.js_modal_lightbox || evt.button != 0) return; - modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) - evt.preventDefault() - showModal(evt) - }, true); - } - }); - } +function setupImageForLightbox(e) { + if (e.dataset.modded) + return; + + e.dataset.modded = true; + e.style.cursor='pointer' + e.style.userSelect='none' + + var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 + + // For Firefox, listening on click first switched to next image then shows the lightbox. 
+ // If you know how to fix this without switching to mousedown event, please. + // For other browsers the event is click to make it possiblr to drag picture. + var event = isFirefox ? 'mousedown' : 'click' + + e.addEventListener(event, function (evt) { + if(!opts.js_modal_lightbox || evt.button != 0) return; + + modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) + evt.preventDefault() + showModal(evt) + }, true); - }, 100); } function modalZoomSet(modalImage, enable) { @@ -199,21 +191,21 @@ function modalTileImageToggle(event) { } function galleryImageHandler(e) { - if (e && e.parentElement.tagName == 'BUTTON') { + //if (e && e.parentElement.tagName == 'BUTTON') { e.onclick = showGalleryImage; - } + //} } onUiUpdate(function() { - fullImg_preview = gradioApp().querySelectorAll('img.w-full') + fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') if (fullImg_preview != null) { - fullImg_preview.forEach(galleryImageHandler); + fullImg_preview.forEach(setupImageForLightbox); } updateOnBackgroundChange(); }) document.addEventListener("DOMContentLoaded", function() { - const modalFragment = document.createDocumentFragment(); + //const modalFragment = document.createDocumentFragment(); const modal = document.createElement('div') modal.onclick = closeModal; modal.id = "lightboxModal"; @@ -277,9 +269,9 @@ document.addEventListener("DOMContentLoaded", function() { modal.appendChild(modalNext) + gradioApp().appendChild(modal) - gradioApp().getRootNode().appendChild(modal) - document.body.appendChild(modalFragment); + document.body.appendChild(modal); }); diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 9ccc9da4..4ac9b8db 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,78 +1,13 @@ // code related to showing and updating progressbar shown as the image is being made - -galleries = {} -storedGallerySelections = {} -galleryObservers = {} - function rememberGallerySelection(id_gallery){ - storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery) -} -function getGallerySelectedIndex(id_gallery){ - let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') - let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - - let currentlySelectedIndex = -1 - galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } }) - - return currentlySelectedIndex } -// this is a workaround for https://github.com/gradio-app/gradio/issues/2984 -function check_gallery(id_gallery){ - let gallery = gradioApp().getElementById(id_gallery) - // if gallery has no change, no need to setting up observer again. 
- if (gallery && galleries[id_gallery] !== gallery){ - galleries[id_gallery] = gallery; - if(galleryObservers[id_gallery]){ - galleryObservers[id_gallery].disconnect(); - } +function getGallerySelectedIndex(id_gallery){ - storedGallerySelections[id_gallery] = -1 - - galleryObservers[id_gallery] = new MutationObserver(function (){ - let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') - let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - let currentlySelectedIndex = getGallerySelectedIndex(id_gallery) - prevSelectedIndex = storedGallerySelections[id_gallery] - storedGallerySelections[id_gallery] = -1 - - if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { - // automatically re-open previously selected index (if exists) - activeElement = gradioApp().activeElement; - let scrollX = window.scrollX; - let scrollY = window.scrollY; - - galleryButtons[prevSelectedIndex].click(); - showGalleryImage(); - - // When the gallery button is clicked, it gains focus and scrolls itself into view - // We need to scroll back to the previous position - setTimeout(function (){ - window.scrollTo(scrollX, scrollY); - }, 50); - - if(activeElement){ - // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it - // if someone has a better solution please by all means - setTimeout(function (){ - activeElement.focus({ - preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it - }) - }, 1); - } - } - }) - galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) - } } -onUiUpdate(function(){ - check_gallery('txt2img_gallery') - check_gallery('img2img_gallery') -}) - function request(url, data, handler, errorHandler){ var xhr = new XMLHttpRequest(); var url = url; diff --git a/javascript/ui.js b/javascript/ui.js index b7a8268a..fcaf5608 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -86,7 +86,7 @@ function get_tab_index(tabId){ var res = 0 gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){ - if(button.className.indexOf('bg-white') != -1) + if(button.className.indexOf('selected') != -1) res = i }) @@ -255,7 +255,6 @@ onUiUpdate(function(){ } prompt.parentElement.insertBefore(counter, prompt) - counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); } diff --git a/modules/scripts.py b/modules/scripts.py index 8de19884..40d8dcc6 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -521,6 +521,9 @@ def IOComponent_init(self, *args, **kwargs): res = original_IOComponent_init(self, *args, **kwargs) + # this adds gradio-* to every component for css styling (ie gradio-button to gr.Button) + self.elem_classes = ["gradio-" + self.get_block_name(), *(self.elem_classes or [])] + script_callbacks.after_component_callback(self, **kwargs) if scripts_current is not None: diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index ce0ebb61..b11568c0 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -109,7 +109,7 @@ class ScriptPostprocessingRunner: inputs = [] for script in self.scripts_in_preferred_order(): - with gr.Box() as group: + with gr.Row() as group: self.create_script_ui(script, inputs) script.group = group diff --git a/modules/ui.py 
b/modules/ui.py index 7e603332..80807ce3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -89,7 +89,7 @@ paste_symbol = '\u2199\ufe0f' # ↙ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ +clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ @@ -179,14 +179,13 @@ def interrogate_deepbooru(image): def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): + with FormRow(elem_id=target_interface + '_seed_row', variant="compact"): seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed') - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) # Components to show/hide based on the 'Extra' checkbox seed_extras = [] @@ -195,8 +194,8 @@ def create_seed_inputs(target_interface): seed_extras.append(seed_extra_row_1) subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed') subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') with FormRow(visible=False) as seed_extra_row_2: @@ -291,19 +290,19 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") button_interrogate = None button_deepbooru = None if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): + with gr.Column(scale=1, elem_classes="interrogate-col"): button_interrogate = 
gr.Button('Interrogate\nCLIP', elem_id="interrogate") button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box"): - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") + skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') skip.click( @@ -325,9 +324,9 @@ def create_toprow(is_img2img): prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") + negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") clear_prompt_button.click( @@ -479,7 +478,9 @@ def create_ui(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") + with gr.Column(elem_id="txt2img_dimensions_row", scale=1): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") + if opts.dimensions_and_batch_together: with gr.Column(elem_id="txt2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") @@ -492,7 +493,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes", variant="compact"): + with FormRow(elem_classes="checkboxes-row", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. 
fix', value=False, elem_id="txt2img_enable_hr") @@ -757,7 +758,9 @@ def create_ui(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + with gr.Column(elem_id="img2img_dimensions_row", scale=1): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") + if opts.dimensions_and_batch_together: with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") @@ -774,7 +777,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes", variant="compact"): + with FormRow(elem_classes="checkboxes-row", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") diff --git a/modules/ui_common.py b/modules/ui_common.py index a12433d2..d4e00829 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -130,7 +130,7 @@ Requested path was: {f} generation_info = None with gr.Column(): with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config) if tabname != "extras": save = gr.Button('Save', elem_id=f'save_{tabname}') @@ -160,6 +160,7 @@ Requested path was: {f} _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", inputs=[generation_info, html_info, html_info], outputs=[html_info, html_info], + show_progress=False, ) save.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 284ca0cf..2b1da2cb 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,55 +1,61 @@ import gradio as gr -class ToolButton(gr.Button, gr.components.FormComponent): - """Small button with single emoji as text, fits inside gradio forms""" +class FormComponent: + def get_expected_parent(self): + return gr.components.Form - def __init__(self, **kwargs): - super().__init__(variant="tool", **kwargs) - def get_block_name(self): - return "button" +gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent -class ToolButtonTop(gr.Button, gr.components.FormComponent): - """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" +class ToolButton(FormComponent, gr.Button): + """Small button with single emoji as text, fits inside gradio forms""" - def __init__(self, **kwargs): - super().__init__(variant="tool-top", **kwargs) + def __init__(self, *args, **kwargs): + classes = kwargs.pop("elem_classes", []) + super().__init__(*args, elem_classes=["tool", *classes], **kwargs) def get_block_name(self): return "button" -class FormRow(gr.Row, gr.components.FormComponent): +class FormRow(FormComponent, gr.Row): """Same as gr.Row but fits inside gradio forms""" def get_block_name(self): return "row" -class FormGroup(gr.Group, gr.components.FormComponent): +class FormColumn(FormComponent, gr.Column): + 
"""Same as gr.Column but fits inside gradio forms""" + + def get_block_name(self): + return "column" + + +class FormGroup(FormComponent, gr.Group): """Same as gr.Row but fits inside gradio forms""" def get_block_name(self): return "group" -class FormHTML(gr.HTML, gr.components.FormComponent): +class FormHTML(FormComponent, gr.HTML): """Same as gr.HTML but fits inside gradio forms""" def get_block_name(self): return "html" -class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): +class FormColorPicker(FormComponent, gr.ColorPicker): """Same as gr.ColorPicker but fits inside gradio forms""" def get_block_name(self): return "colorpicker" -class DropdownMulti(gr.Dropdown): +class DropdownMulti(FormComponent, gr.Dropdown): """Same as gr.Dropdown but always multiselect""" def __init__(self, **kwargs): super().__init__(multiselect=True, **kwargs) diff --git a/requirements.txt b/requirements.txt index 6d53f089..e71251c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ basicsr fonts font-roboto gfpgan -gradio==3.16.2 +gradio==3.22.1 invisible-watermark numpy omegaconf diff --git a/requirements_versions.txt b/requirements_versions.txt index 0031c616..ab16e4cc 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -3,7 +3,7 @@ transformers==4.25.1 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.16.2 +gradio==3.22.1 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 diff --git a/script.js b/script.js index 97e0bfcf..978b948f 100644 --- a/script.js +++ b/script.js @@ -1,7 +1,9 @@ function gradioApp() { const elems = document.getElementsByTagName('gradio-app') - const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot - return !!gradioShadowRoot ? gradioShadowRoot : document; + const elem = elems.length == 0 ? document : elems[0] + + elem.getElementById = function(id){ return document.getElementById(id) } + return elem.shadowRoot ? 
elem.shadowRoot : elem } function get_uiCurrentTab() { diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 8842bd91..11eab31a 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -17,22 +17,24 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): def ui(self): selected_tab = gr.State(value=0) - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: - with FormRow(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with FormRow(): - extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - - with FormRow(): - extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + with gr.Column(): + with FormRow(): + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) diff --git a/style.css b/style.css index 3eac2b17..936930fd 100644 --- a/style.css +++ b/style.css @@ -1,270 +1,259 @@ -.container { - max-width: 100%; -} -.token-counter{ - position: absolute; - display: inline-block; - right: 2em; - min-width: 0 !important; - width: auto; - z-index: 100; -} +/* general gradio fixes */ -.token-counter.error span{ - box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); - border: 2px solid rgba(255,0,0,0.4) !important; +:root{ + --checkbox-label-gap: 0.25em 0.1em; + --section-header-text-size: 12pt; } 
-.token-counter div{ - display: inline; +.block.padded{ + padding: 0.2em 0.5em !important; } -.token-counter span{ - padding: 0.1em 0.75em; +div.gradio-container{ + max-width: unset !important; } -#sh{ - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; - margin: 0.1em 0; - opacity: 0%; - cursor: default; +.hidden{ + display: none; } -.output-html p {margin: 0 0.5em;} - -.row > *, -.row > .gr-form > * { - min-width: min(120px, 100%); - flex: 1 1 0%; +.compact{ + background: transparent !important; + padding: 0 !important; } -.performance { - font-size: 0.85em; - color: #444; +div.form{ + border-width: 0; + box-shadow: none; + background: transparent; + overflow: visible; + gap: 0.5em; } -.performance p{ - display: inline-block; +.block.gradio-dropdown, +.block.gradio-slider, +.block.gradio-checkbox, +.block.gradio-textbox, +.block.gradio-radio, +.block.gradio-checkboxgroup, +.block.gradio-number +{ + border-width: 0 !important; + box-shadow: none !important; } -.performance .time { - margin-right: 0; +.gap.compact{ + padding: 0; + gap: 0; } -.performance .vram { +div.compact{ + gap: 0.5em; } -#txt2img_generate, #img2img_generate { - min-height: 4.5em; +.gradio-dropdown ul.options{ + max-height: 35em; } -@media screen and (min-width: 2500px) { - #txt2img_gallery, #img2img_gallery { - min-height: 768px; - } +.gradio-dropdown label span:not(.has-info){ + margin-bottom: 0; } -#txt2img_gallery img, #img2img_gallery img{ - object-fit: scale-down; -} -#txt2img_actions_column, #img2img_actions_column { - margin: 0.35rem 0.75rem 0.35rem 0; -} -#script_list { - padding: .625rem .75rem 0 .625rem; -} -.justify-center.overflow-x-scroll { - justify-content: left; +.gradio-dropdown div.wrap.wrap.wrap.wrap{ + box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05); } -.justify-center.overflow-x-scroll button:first-of-type { - margin-left: auto; +.gradio-slider input[type="number"]{ + width: 6em; } -.justify-center.overflow-x-scroll button:last-of-type { - margin-right: auto; +.block.gradio-checkbox { + margin: 0.75em 1.5em 0 0; } -[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{ - min-width: 2.3em; - height: 2.5em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; -} +/* general styled components */ -#hidden_element{ - display: none; +.gradio-button.tool{ + max-width: 2.2em; + min-width: 2.2em !important; + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; } -[id$=_seed_row], [id$=_subseed_row]{ - gap: 0.5rem; - padding: 0.6em; +.checkboxes-row{ + margin-bottom: 0.5em; + margin-left: 0em; } - -[id$=_subseed_show_box]{ +.checkboxes-row > div{ + flex: 0; + white-space: nowrap; min-width: auto; - flex-grow: 0; } -[id$=_subseed_show_box] > div{ - border: 0; - height: 100%; -} -[id$=_subseed_show]{ - min-width: auto; - flex-grow: 0; - padding: 0; -} +/* txt2img/img2img specific */ -[id$=_subseed_show] label{ - height: 100%; +.block.token-counter{ + position: absolute; + display: inline-block; + right: 1em; + min-width: 0 !important; + width: auto; + z-index: 100; } -#txt2img_actions_column, #img2img_actions_column{ - gap: 0; - margin-right: .75rem; +.block.token-counter span{ + background: var(--input-background-fill) !important; + box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075); + border: 2px solid rgba(192,192,192,0.4) !important; + border-radius: 0.4em; } -#txt2img_tools, #img2img_tools{ - gap: 0.4em; 
+.block.token-counter.error span{ + box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); + border: 2px solid rgba(255,0,0,0.4) !important; } -#interrogate_col{ - min-width: 0 !important; - max-width: 8em !important; - margin-right: 1em; - gap: 0; -} -#interrogate, #deepbooru{ - margin: 0em 0.25em 0.5em 0.25em; - min-width: 8em; - max-width: 8em; +.block.token-counter div{ + display: inline; } -#style_pos_col, #style_neg_col{ - min-width: 8em !important; +.block.token-counter span{ + padding: 0.1em 0.75em; } -#txt2img_styles_row, #img2img_styles_row{ - gap: 0.25em; - margin-top: 0.3em; +[id$=_subseed_show]{ + min-width: auto !important; + flex-grow: 0 !important; + display: flex; } -#txt2img_styles_row > button, #img2img_styles_row > button{ - margin: 0; +[id$=_subseed_show] label{ + margin-bottom: 0.5em; + align-self: end; } -#txt2img_styles, #img2img_styles{ - padding: 0; +.performance { + font-size: 0.85em; + color: #444; } -#txt2img_styles > label > div, #img2img_styles > label > div{ - min-height: 3.2em; +.performance p{ + display: inline-block; } -ul.list-none{ - max-height: 35em; - z-index: 2000; +.performance .time { + margin-right: 0; } -.gr-form{ - background: transparent; +.performance .vram { } -.my-4{ - margin-top: 0; - margin-bottom: 0; +#txt2img_generate, #img2img_generate { + min-height: 4.5em; } -#resize_mode{ - flex: 1.5; +@media screen and (min-width: 2500px) { + #txt2img_gallery, #img2img_gallery { + min-height: 768px; + } } -button{ - align-self: stretch !important; +#txt2img_gallery img, #img2img_gallery img{ + object-fit: scale-down; } - -.overflow-hidden, .gr-panel{ - overflow: visible !important; +#txt2img_actions_column, #img2img_actions_column { + gap: 0.5em; +} +#txt2img_tools, #img2img_tools{ + gap: 0.4em; } -#x_type, #y_type{ - max-width: 10em; +.interrogate-col{ + min-width: 0 !important; + max-width: fit-content; + gap: 0.5em; +} +.interrogate-col > button{ + min-width: 8em; + max-width: 8em; + height: 5.45em; } -#txt2img_preview, #img2img_preview, #ti_preview{ +.generate-box{ + position: relative; +} +.gradio-button.generate-box-skip, .gradio-button.generate-box-interrupt{ position: absolute; - width: 320px; + width: 50%; + height: 100%; + display: none; +} +.gradio-button.generate-box-interrupt{ left: 0; + border-radius: 0.5rem 0 0 0.5rem; +} +.gradio-button.generate-box-skip{ right: 0; - margin-left: auto; - margin-right: auto; - margin-top: 34px; - z-index: 100; - border: none; - border-top-left-radius: 0; - border-top-right-radius: 0; + border-radius: 0 0.5rem 0.5rem 0; } -@media screen and (min-width: 768px) { - #txt2img_preview, #img2img_preview, #ti_preview { - position: absolute; - } +#txtimg_hr_finalres{ + min-height: 0 !important; + padding: .625rem .75rem; + margin-left: -0.75em } -@media screen and (max-width: 767px) { - #txt2img_preview, #img2img_preview, #ti_preview { - position: relative; - } +#txtimg_hr_finalres .resolution{ + font-weight: bold; } -#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{ - display: none; +.inactive{ + opacity: 0.5; } -fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ - position: absolute; - top: -0.7em; - line-height: 1.2em; - padding: 0; - margin: 0 0.5em; - - background-color: white; - box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white; +[id$=_column_batch]{ + min-width: min(13.5em, 100%) !important; +} - z-index: 300; +[id$=_dimensions_row]{ + min-width: 0 !important; + max-width: fit-content; + 
padding: 0 1em; } -.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ - background-color: rgb(31, 41, 55); - box-shadow: none; - border: 1px solid rgba(128, 128, 128, 0.1); - border-radius: 6px; - padding: 0.1em 0.5em; +#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{ + height: 480px !important; + max-height: 480px !important; + min-height: 480px !important; } -#txt2img_column_batch, #img2img_column_batch{ - min-width: min(13.5em, 100%) !important; + +/* settings */ +#quicksettings { + width: fit-content; } -#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - position: relative; +#quicksettings > div, #quicksettings > fieldset{ + max-width: 24em; + min-width: 24em; + padding: 0; border: none; - margin-right: 8em; + box-shadow: none; + background: none; + margin-right: 10px; } -#settings .gr-panel div.flex-col div.justify-between div{ - position: relative; - z-index: 200; +#quicksettings .gradio-dropdown .wrap-inner{ + flex-wrap: unset; +} + +#quicksettings .gradio-dropdown .single-select{ + white-space: nowrap; + overflow: hidden; } #settings{ @@ -276,14 +265,14 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s margin-left: 10em; } -#settings > div.flex-wrap{ +#settings > div.tab-nav{ float: left; display: block; margin-left: 0; width: 10em; } -#settings > div.flex-wrap button{ +#settings > div.tab-nav button{ display: block; border: none; text-align: left; @@ -294,29 +283,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s margin: 0 1.2em; } -input[type="range"]{ - margin: 0.5em 0 -0.3em 0; -} - -#mask_bug_info { - text-align: center; - display: block; - margin-top: -0.75em; - margin-bottom: -0.75em; -} - -#txt2img_negative_prompt, #img2img_negative_prompt{ -} - -/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */ -.transition.opacity-20 { - opacity: 1 !important; -} - -/* more gradio's garbage cleanup */ -.min-h-\[4rem\] { min-height: unset !important; } -.min-h-\[6rem\] { min-height: unset !important; } +/* live preview */ .progressDiv{ position: relative; height: 20px; @@ -362,6 +330,8 @@ input[type="range"]{ height: 100%; } +/* fullscreen popup (ie in Lora's (i) button) */ + .popup-metadata{ color: black; background: white; @@ -402,6 +372,8 @@ input[type="range"]{ padding: 2em; } +/* fullpage image viewer */ + #lightboxModal{ display: none; position: fixed; @@ -512,45 +484,7 @@ input[type="range"]{ background-color: rgba(0, 0, 0, 0.8); } -#imageARPreview{ - position:absolute; - top:0px; - left:0px; - border:2px solid red; - background:rgba(255, 0, 0, 0.3); - z-index: 900; - pointer-events:none; - display:none -} - -#txt2img_generate_box, #img2img_generate_box{ - position: relative; -} - -#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ - position: absolute; - width: 50%; - height: 100%; - background: #b4c0cc; - display: none; -} - -#txt2img_interrupt, #img2img_interrupt{ - left: 0; - border-radius: 0.5rem 0 0 0.5rem; -} -#txt2img_skip, #img2img_skip{ - right: 0; - border-radius: 0 0.5rem 0.5rem 0; -} - -.red { - color: red; -} - -.gallery-item { - --tw-bg-opacity: 0 !important; -} +/* context menu (ie for the generate button) */ #context-menu{ z-index:9999; @@ -579,61 +513,8 @@ input[type="range"]{ background: #a55000; } -#quicksettings { - width: fit-content; -} -#quicksettings > div, #quicksettings > 
fieldset{ - max-width: 24em; - min-width: 24em; - padding: 0; - border: none; - box-shadow: none; - background: none; - margin-right: 10px; -} - -#quicksettings > div > div > div > label > span { - position: relative; - margin-right: 9em; - margin-bottom: -1em; -} - -canvas[key="mask"] { - z-index: 12 !important; - filter: invert(); - mix-blend-mode: multiply; - pointer-events: none; -} - - -/* gradio 3.4.1 stuff for editable scrollbar values */ -.gr-box > div > div > input.gr-text-input{ - position: absolute; - right: 0.5em; - top: -0.6em; - z-index: 400; - width: 6em; -} -#quicksettings .gr-box > div > div > input.gr-text-input { - top: -1.12em; -} - -.row.gr-compact{ - overflow: visible; -} - -#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, -#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, -#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, -#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img -{ - height: 480px !important; - max-height: 480px !important; - min-height: 480px !important; -} - -/* Extensions */ +/* extensions */ #tab_extensions table{ border-collapse: collapse; @@ -646,6 +527,7 @@ canvas[key="mask"] { #tab_extensions table input[type="checkbox"]{ margin-right: 0.5em; + appearance: checkbox; } #tab_extensions button{ @@ -670,74 +552,7 @@ canvas[key="mask"] { font-size: 90%; } -#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ - min-width: auto; - padding-left: 0.5em; - padding-right: 0.5em; -} - -.gr-form{ - background-color: white; -} - -.dark .gr-form{ - background-color: rgb(31 41 55 / var(--tw-bg-opacity)); -} - -.gr-button-tool, .gr-button-tool-top{ - max-width: 2.5em; - min-width: 2.5em !important; - height: 2.4em; -} - -.gr-button-tool{ - margin: 0.6em 0em 0.55em 0; -} - -.gr-button-tool-top, #settings .gr-button-tool{ - margin: 1.6em 0.7em 0.55em 0; -} - - -#modelmerger_results_container{ - margin-top: 1em; - overflow: visible; -} - -#modelmerger_models{ - gap: 0; -} - - -#quicksettings .gr-button-tool{ - margin: 0; - border-color: unset; - background-color: unset; -} - -#modelmerger_interp_description>p { - margin: 0!important; - text-align: center; -} -#modelmerger_interp_description { - margin: 0.35rem 0.75rem 1.23rem; -} -#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form { - padding-top: 0.9em; - padding-bottom: 0.9em; -} -#txt2img_settings { - padding-top: 1.16em; - padding-bottom: 0.9em; -} -#img2img_settings { - padding-bottom: 0.9em; -} - -#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{ - border: none; - padding-bottom: 0.5em; -} +/* replace original footer with ours */ footer { display: none !important; @@ -756,90 +571,7 @@ footer { opacity: 0.85; } -#txtimg_hr_finalres{ - min-height: 0 !important; - padding: .625rem .75rem; - margin-left: -0.75em - -} - -#txtimg_hr_finalres .resolution{ - font-weight: bold; -} - -#txt2img_checkboxes, #img2img_checkboxes{ - margin-bottom: 0.5em; - margin-left: 0em; -} -#txt2img_checkboxes > div, #img2img_checkboxes > div{ - flex: 0; - white-space: nowrap; - min-width: auto; -} - -#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ - margin-left: 0em; -} - -#axis_options { - margin-left: 0em; -} - -.inactive{ - opacity: 0.5; -} 
- -[id*='_prompt_container']{ - gap: 0; -} - -[id*='_prompt_container'] > div{ - margin: -0.4em 0 0 0; -} - -.gr-compact { - border: none; -} - -.dark .gr-compact{ - background-color: rgb(31 41 55 / var(--tw-bg-opacity)); - margin-left: 0; -} - -.gr-compact{ - overflow: visible; -} - -.gr-compact > *{ -} - -.gr-compact .gr-block, .gr-compact .gr-form{ - border: none; - box-shadow: none; -} - -.gr-compact .gr-box{ - border-radius: .5rem !important; - border-width: 1px !important; -} - -#mode_img2img > div > div{ - gap: 0 !important; -} - -[id*='img2img_copy_to_'] { - border: none; -} - -[id*='img2img_copy_to_'] > button { -} - -[id*='img2img_label_copy_to_'] { - font-size: 1.0em; - font-weight: bold; - text-align: center; - line-height: 2.4em; -} +/* extra networks UI */ .extra-networks > div > [id *= '_extra_']{ margin: 0.3em; @@ -1025,7 +757,3 @@ footer { .extra-network-cards .card ul a:hover{ color: red; } - -[id*='_prompt_container'] > div { - margin: 0!important; -} -- cgit v1.2.3 From 80b26d2a69617b75d2d01c1e6b7d11445815ed4d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 25 Mar 2023 23:06:33 +0300 Subject: apply Lora by altering layer's weights instead of adding more calculations in forward() --- extensions-builtin/Lora/lora.py | 72 ++++++++++++++++++++------ extensions-builtin/Lora/scripts/lora_script.py | 12 ++++- 2 files changed, 66 insertions(+), 18 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 7c371deb..a737fec3 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -131,7 +131,7 @@ def load_lora(name, filename): with torch.no_grad(): module.weight.copy_(weight) - module.to(device=devices.device, dtype=devices.dtype) + module.to(device=devices.cpu, dtype=devices.dtype) if lora_key == "lora_up.weight": lora_module.up = module @@ -177,29 +177,69 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_forward(module, input, res): - input = devices.cond_cast_unet(input) - if len(loaded_loras) == 0: - return res +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): + """ + Applies the currently selected set of Loras to the weight of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. 
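In other words, each LoRA pair is collapsed into a single weight delta (up @ down, scaled by the user multiplier and by alpha divided by the rank) that is added to the layer's weight once, instead of being recomputed inside every forward call, and a backup of the untouched weights is kept so the change can be undone. Below is a minimal standalone sketch of that idea on a plain torch.nn.Linear; the function and attribute names are illustrative only and are not the extension's actual API.

import torch

def calc_updown(up, down, multiplier, alpha):
    # Collapse a LoRA pair into one weight delta: (up @ down) scaled by the
    # user multiplier and by alpha divided by the rank (up.shape[1]).
    scale = multiplier * (alpha / up.shape[1] if alpha else 1.0)
    return (up @ down) * scale

def apply_lora_to_linear(layer, up, down, multiplier=1.0, alpha=None):
    # Bake the delta into layer.weight once, keeping a CPU backup so the
    # original weights can be restored when the active LoRA set changes.
    if getattr(layer, "weight_backup", None) is None:
        layer.weight_backup = layer.weight.detach().cpu().clone()
    with torch.no_grad():
        layer.weight += calc_updown(up.to(layer.weight), down.to(layer.weight), multiplier, alpha)

def restore_linear(layer):
    # Undo apply_lora_to_linear by copying the backed-up weights back in.
    backup = getattr(layer, "weight_backup", None)
    if backup is not None:
        with torch.no_grad():
            layer.weight.copy_(backup.to(layer.weight.device))

if __name__ == "__main__":
    features, rank = 16, 4
    layer = torch.nn.Linear(features, features, bias=False)
    up = torch.randn(features, rank) * 0.01      # plays the role of lora_up.weight
    down = torch.randn(rank, features) * 0.01    # plays the role of lora_down.weight
    x = torch.randn(2, features)

    baseline = layer(x)
    apply_lora_to_linear(layer, up, down, multiplier=0.8, alpha=rank)
    # Same result as the previous forward-time formulation res + up(down(x)) * scale,
    # but with no extra matmuls inside forward().
    merged = layer(x)
    assert torch.allclose(merged, baseline + (x @ (up @ down).T) * 0.8, atol=1e-5)

    restore_linear(layer)
    assert torch.allclose(layer(x), baseline)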
+ """ - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + self.weight.copy_(weights_backup) + + lora_layer_name = getattr(self, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue - return res + with torch.no_grad(): + up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) + down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + setattr(self, "lora_current_names", wanted_names) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160..dc329e81 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora def before_ui(): @@ -20,11 +22,19 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'): torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = 
torch.nn.Linear._load_from_state_dict + if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), - "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), - })) -- cgit v1.2.3 From 650ddc9dd3c1d126221682be8270f7fba1b5b6ce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 26 Mar 2023 10:44:20 +0300 Subject: Lora support for SD2 --- extensions-builtin/Lora/lora.py | 155 ++++++++++++++++++------- extensions-builtin/Lora/scripts/lora_script.py | 10 ++ 2 files changed, 126 insertions(+), 39 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index d4345ada..edd95f78 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} re_digits = re.compile(r"\d+") -re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") -re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") -re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") -re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + r = re.match(regex, key) if not r: return False @@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2): m = [] - if match(m, re_unet_down_blocks): - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - if match(m, re_unet_mid_blocks): - return 
f"diffusion_model_middle_block_1_{m[1]}" + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - if match(m, re_unet_up_blocks): - return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - if match(m, re_text_block): + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): if is_sd2: if 'mlp_fc1' in m[1]: return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" @@ -109,16 +131,22 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) - keys_failed_to_match = [] + keys_failed_to_match = {} is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) - key, lora_key = fullkey.split(".", 1) + key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) + key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: - keys_failed_to_match.append(key_diffusers) + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + + if sd_module is None: + keys_failed_to_match[key_diffusers] = key continue lora_module = lora.modules.get(key, None) @@ -133,7 +161,9 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: @@ -190,54 +220,94 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): +def lora_calc_updown(lora, module, target): + with torch.no_grad(): + up = module.up.weight.to(target.device, dtype=target.dtype) + down = module.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return updown + + +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention): """ - Applies the currently selected set of Loras to the weight of torch layer self. + Applies the currently selected set of Loras to the weights of torch layer self. If weights already have this particular set of loras applied, does nothing. If not, restores orginal weights from backup and alters weights according to loras. 
""" + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + current_names = getattr(self, "lora_current_names", ()) wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) weights_backup = getattr(self, "lora_weights_backup", None) if weights_backup is None: - weights_backup = self.weight.to(devices.cpu, copy=True) + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup if current_names != wanted_names: if weights_backup is not None: - self.weight.copy_(weights_backup) + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) - lora_layer_name = getattr(self, 'lora_layer_name', None) for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) - if module is None: + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) continue - with torch.no_grad(): - up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) - down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - else: - updown = up @ down + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + + if module is None: + continue - self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + print(f'failed to calculate lora weights for layer {lora_layer_name}') setattr(self, "lora_current_names", wanted_names) +def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + def lora_Linear_forward(self, input): lora_apply_weights(self) return torch.nn.Linear_forward_before_lora(self, input) -def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) @@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input): return torch.nn.Conv2d_forward_before_lora(self, input) -def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + 
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index dc329e81..0adab225 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -12,6 +12,8 @@ def unload():
     torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
     torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
 
 
 def before_ui():
@@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
     torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
 
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
 torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)

-- cgit v1.2.3
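Hooking _load_from_state_dict alongside forward matters because the cached lora_weights_backup describes whatever checkpoint was loaded when the backup was taken; once fresh weights land in a layer, both the backup and the lora_current_names cache are stale, which is what lora_reset_cached_weight clears. A toy sketch of that invalidation on a bare Linear layer (the attribute names mirror the patch, the rest is illustrative):

    import torch

    def reset_cached_weight(module):
        module.lora_current_names = ()
        module.lora_weights_backup = None

    original_load = torch.nn.Linear._load_from_state_dict

    def load_hook(self, *args, **kwargs):
        reset_cached_weight(self)   # a stale backup would restore the *old* checkpoint's weights
        return original_load(self, *args, **kwargs)

    torch.nn.Linear._load_from_state_dict = load_hook

    layer = torch.nn.Linear(4, 4)
    layer.lora_weights_backup = layer.weight.detach().clone()
    layer.load_state_dict({"weight": torch.zeros(4, 4), "bias": torch.zeros(4)})
    print(layer.lora_weights_backup)  # None: backup dropped when new weights were loaded

    torch.nn.Linear._load_from_state_dict = original_load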
From 9d7390d2d19a8baf04ee4ebe598b96ac6ba7f97e Mon Sep 17 00:00:00 2001
From: camenduru <54370274+camenduru@users.noreply.github.com>
Date: Mon, 27 Mar 2023 04:28:40 +0300
Subject: convert to python v3.9

---
 extensions-builtin/Lora/lora.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index edd95f78..79d11e0e 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -2,6 +2,7 @@ import glob
 import os
 import re
 import torch
+from typing import Union
 
 from modules import shared, devices, sd_models, errors
 
@@ -235,7 +236,7 @@ def lora_calc_updown(lora, module, target):
         return updown
 
 
-def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
+def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
     If weights already have this particular set of loras applied, does nothing.

-- cgit v1.2.3

From 6a147db1287fe660e1bfb2ebf5b3fadc14835c69 Mon Sep 17 00:00:00 2001
From: camenduru <54370274+camenduru@users.noreply.github.com>
Date: Mon, 27 Mar 2023 04:40:31 +0300
Subject: convert to python v3.9

---
 extensions-builtin/Lora/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 79d11e0e..696be8ea 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -296,7 +296,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
         setattr(self, "lora_current_names", wanted_names)
 
 
-def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear):
+def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     setattr(self, "lora_current_names", ())
     setattr(self, "lora_weights_backup", None)
 
-- cgit v1.2.3
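The two commits above only swap the PEP 604 `A | B` annotations for typing.Union: the `|` operator between classes is evaluated when the def statement runs, so it raises TypeError on Python 3.9, while Union keeps the module importable there. For illustration (the `_39` suffix is just for the example):

    # `def f(x: torch.nn.Conv2d | torch.nn.Linear)` fails at definition time on 3.9;
    # typing.Union expresses the same annotation on older interpreters.
    from typing import Union
    import torch

    def lora_reset_cached_weight_39(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
        self.lora_current_names = ()
        self.lora_weights_backup = None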